Skip to content

Commit

Permalink
Merge pull request #1 from reibitto/improve-balancing
Browse files Browse the repository at this point in the history
Improve balancing with JUnit reports support
  • Loading branch information
reibitto authored Nov 18, 2023
2 parents 47ae219 + bbf4900 commit 0ad08db
Show file tree
Hide file tree
Showing 15 changed files with 306 additions and 73 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ jobs:
publish:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2.3.4
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: olafurpg/setup-scala@v13
- uses: olafurpg/setup-gpg@v3
- uses: actions/setup-java@v3
with:
distribution: temurin
java-version: 8
cache: sbt
- run: sbt ci-release
env:
PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
PGP_SECRET: ${{ secrets.PGP_SECRET }}
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
16 changes: 10 additions & 6 deletions .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@ on:
branches: [master, main]

jobs:
build:
jvm:
strategy:
fail-fast: false
matrix:
scala: [2.12.18]
java: [[email protected], [email protected]]
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up JDK 11
uses: actions/setup-java@v1
- name: Set up environment
uses: olafurpg/setup-scala@v10
with:
java-version: 11
java-version: ${{ matrix.java }}

- name: Run tests
run: sbt -Dscalac.unused.enabled=true fmtCheck test
run: sbt ++${{ matrix.scala}} fmtCheck test
50 changes: 38 additions & 12 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
version = "3.2.1"
version = 3.7.14

runner.dialect = scala213
runner.dialect = scala213source3

maxColumn = 120
align.preset = most
continuationIndent.defnSite = 2
align = none
maxColumn = 120 # For wide displays.
assumeStandardLibraryStripMargin = true
align.tokens = [{code = "=>", owner = "Case"}, {code = "<-"}]
align.allowOverflow = true
align.arrowEnumeratorGenerator = true
align.openParenCallSite = false
align.openParenDefnSite = false
newlines.topLevelStatementBlankLines = [
{blanks {before = 1}}
{regex = "^Import|^Term.ApplyInfix"}
]
newlines.alwaysBeforeElseAfterCurlyIf = false
indentOperator.topLevelOnly = true
docstrings.style = SpaceAsterisk
lineEndings = preserve
docstrings.wrapMaxColumn = 80

includeCurlyBraceInSelectChains = false
danglingParentheses.preset = true
spaces {
inImportCurlyBraces = true
}
optIn.annotationNewlines = true
includeNoParensInSelectChains = false

rewrite.rules = [SortModifiers, PreferCurlyFors, SortImports, RedundantBraces]
rewrite.scala3.convertToNewSyntax = true
rewrite.imports.groups = [
[".*"],
["java\\..*", "javax\\..*", "scala\\..*"]
]

project.excludeFilters = ["/target/"]

lineEndings = preserve

rewrite.rules = [SortImports, RedundantBraces]
fileOverride {
"glob:**/*.sbt" {
runner.dialect = sbt1
rewrite.scala3.convertToNewSyntax = false
}
"glob:**/project/*.scala" {
rewrite.scala3.convertToNewSyntax = false
}
}
24 changes: 14 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import sbtwelcome._
import sbtwelcome.*

inThisBuild(
List(
organization := "com.github.reibitto",
homepage := Some(url("https://github.com/reibitto/sbt-test-shards")),
licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")),
developers := List(
homepage := Some(url("https://github.com/reibitto/sbt-test-shards")),
licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")),
developers := List(
Developer("reibitto", "reibitto", "[email protected]", url("https://reibitto.github.io"))
)
)
)

lazy val root = (project in file(".")).settings(
name := "sbt-test-shards",
name := "sbt-test-shards",
organization := "com.github.reibitto",
scalaVersion := "2.12.16",
sbtPlugin := true
scalaVersion := "2.12.18",
sbtPlugin := true,
libraryDependencies ++= Seq(
"org.scalameta" %% "munit" % "0.7.29" % Test,
"org.scalameta" %% "munit-scalacheck" % "0.7.29" % Test
)
)

addCommandAlias("fmt", "all root/scalafmtSbt root/scalafmtAll")
Expand All @@ -41,9 +45,9 @@ logo :=
|""".stripMargin

usefulTasks := Seq(
UsefulTask("a", "~compile", "Compile with file-watch enabled"),
UsefulTask("b", "fmt", "Run scalafmt on the entire project"),
UsefulTask("c", "publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project")
UsefulTask("~compile", "Compile with file-watch enabled"),
UsefulTask("fmt", "Run scalafmt on the entire project"),
UsefulTask("publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project")
)

logoColor := scala.Console.MAGENTA
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.7.1
sbt.version=1.9.7
8 changes: 5 additions & 3 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")
addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.10")
addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.2.2")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")

addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12")

addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.4.0")
77 changes: 57 additions & 20 deletions src/main/scala/sbttestshards/ShardingAlgorithm.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package sbttestshards

import sbttestshards.parsers.JUnitReportParser

import java.nio.charset.StandardCharsets
import java.nio.file.Path
import java.time.Duration
import scala.util.hashing.MurmurHash3

// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
trait ShardingAlgorithm {
Expand All @@ -11,34 +16,60 @@ trait ShardingAlgorithm {

object ShardingAlgorithm {

/** Shards by suite the name. This is the most reasonable default as it requires no additional setup. */
/** Shards by suite the name. This is the most reasonable default as it
* requires no additional setup.
*/
final case object SuiteName extends ShardingAlgorithm {

override def shouldRun(specName: String, shardContext: ShardContext): Boolean =
// TODO: Test whether `hashCode` gets a good distribution. Otherwise implement a different hash algorithm.
specName.hashCode.abs % shardContext.testShardCount == shardContext.testShard
MurmurHash3
.bytesHash(specName.getBytes(StandardCharsets.UTF_8))
.abs % shardContext.testShardCount == shardContext.testShard
}

/** Will always mark the test to run on this shard. Useful for debugging or for fallback algorithms. */
/** Will always mark the test to run on this shard. Useful for debugging or
* for fallback algorithms.
*/
final case object Always extends ShardingAlgorithm {
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = true
}

/** Will never mark the test to run on this shard. Useful for debugging or for fallback algorithms. */
/** Will never mark the test to run on this shard. Useful for debugging or for
* fallback algorithms.
*/
final case object Never extends ShardingAlgorithm {
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false
}

/** Attempts to balance the shards by execution time so that no one shard takes significantly longer to complete than
* another.
object Balance {

def fromJUnitReports(
reportDirectories: Seq[Path],
shardsInfo: ShardingInfo,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
): Balance =
ShardingAlgorithm.Balance(
JUnitReportParser.parseDirectoriesRecursively(reportDirectories).testReports.map { r =>
SpecInfo(r.name, Some(Duration.ofMillis(r.timeTaken.toLong)))
},
shardsInfo,
fallbackShardingAlgorithm
)
}

/** Attempts to balance the shards by execution time so that no one shard
* takes significantly longer to complete than another.
*/
final case class Balance(
tests: List[TestSuiteInfo],
bucketCount: Int,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
specs: Seq[SpecInfo],
shardsInfo: ShardingInfo,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
) extends ShardingAlgorithm {

// TODO: Median might be better here?
private val averageTime: Option[Duration] = {
val allTimeTaken = tests.flatMap(_.timeTaken)
val allTimeTaken = specs.flatMap(_.timeTaken)

allTimeTaken.reduceOption(_.plus(_)).map { d =>
if (allTimeTaken.isEmpty) Duration.ZERO
else d.dividedBy(allTimeTaken.length)
Expand All @@ -48,14 +79,19 @@ object ShardingAlgorithm {
// TODO: This uses a naive greedy algorithm for partitioning into approximately equal subsets. While this problem
// is NP-complete, there's a lot of room for improvement with other algorithms. Dynamic programming should be
// possible here.
private def createBucketMap(testShardCount: Int) = {
def distributeEvenly: Map[TestSuiteInfoSimple, Int] = {
val durationOrdering: Ordering[Duration] = (a: Duration, b: Duration) => a.compareTo(b)

val allTests = tests
val allTests = specs
.map(t => TestSuiteInfoSimple(t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration.ZERO))))
.sortBy(_.timeTaken)(durationOrdering.reverse)

val buckets = Array.fill(testShardCount)(TestBucket(Nil, Duration.ZERO))
val buckets = (0 until shardsInfo.shardCount).map { shardIndex =>
TestBucket(
tests = Nil,
sum = shardsInfo.initialDurations.getOrElse(shardIndex, Duration.ZERO)
)
}.toArray

allTests.foreach { test =>
val minBucket = buckets.minBy(_.sum)
Expand All @@ -66,14 +102,14 @@ object ShardingAlgorithm {

buckets.zipWithIndex.flatMap { case (bucket, i) =>
bucket.tests.map { info =>
info.name -> i
info -> i
}
}.toMap
}

// `bucketCount` doesn't necessary need to match `testShardCount`, but ideally it should be a multiple of it.
// TODO: Maybe print a warning if it's not a multiple of it.
private val bucketMap: Map[String, Int] = createBucketMap(bucketCount)
private val bucketMap: Map[String, Int] = distributeEvenly.map { case (k, v) =>
k.name -> v
}

def shouldRun(specName: String, shardContext: ShardContext): Boolean =
bucketMap.get(specName) match {
Expand All @@ -82,6 +118,7 @@ object ShardingAlgorithm {
}
}

private final case class TestSuiteInfoSimple(name: String, timeTaken: Duration)
private final case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration)
final case class TestSuiteInfoSimple(name: String, timeTaken: Duration)

final private case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration)
}
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/ShardingInfo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

final case class ShardingInfo(shardCount: Int, initialDurations: Map[Int, Duration] = Map.empty)
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/SpecBucketItem.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

final case class SpecBucketItem(name: String, timeTaken: Duration)
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/SpecInfo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

final case class SpecInfo(name: String, timeTaken: Option[Duration])
5 changes: 0 additions & 5 deletions src/main/scala/sbttestshards/TestBucketItem.scala

This file was deleted.

15 changes: 8 additions & 7 deletions src/main/scala/sbttestshards/TestShardsPlugin.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package sbttestshards

import sbt.Keys.*
import sbt.*
import sbt.Keys.*

object TestShardsPlugin extends AutoPlugin {

object autoImport {
val testShard = settingKey[Int]("testShard")
val testShardCount = settingKey[Int]("testShardCount")
val testShard = settingKey[Int]("testShard")
val testShardCount = settingKey[Int]("testShardCount")
val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm")
val testShardDebug = settingKey[Boolean]("testShardDebug")
val testShardDebug = settingKey[Boolean]("testShardDebug")
}

import autoImport.*
Expand All @@ -22,10 +23,10 @@ object TestShardsPlugin extends AutoPlugin {

override lazy val projectSettings: Seq[Def.Setting[?]] =
Seq(
testShard := stringConfig("TEST_SHARD", "0").toInt,
testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt,
testShard := stringConfig("TEST_SHARD", "0").toInt,
testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt,
shardingAlgorithm := ShardingAlgorithm.SuiteName,
testShardDebug := false,
testShardDebug := false,
Test / testOptions += {
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
Tests.Filter { specName =>
Expand Down
5 changes: 0 additions & 5 deletions src/main/scala/sbttestshards/TestSuiteInfo.scala

This file was deleted.

Loading

0 comments on commit 0ad08db

Please sign in to comment.