Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve balancing with JUnit reports support #1

Merged
merged 2 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ jobs:
publish:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2.3.4
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: olafurpg/setup-scala@v13
- uses: olafurpg/setup-gpg@v3
- uses: actions/setup-java@v3
with:
distribution: temurin
java-version: 8
cache: sbt
- run: sbt ci-release
env:
PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
PGP_SECRET: ${{ secrets.PGP_SECRET }}
SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }}
16 changes: 10 additions & 6 deletions .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@ on:
branches: [master, main]

jobs:
build:
jvm:
strategy:
fail-fast: false
matrix:
scala: [2.12.18]
java: [adopt@1.8, adopt@1.11]
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up JDK 11
uses: actions/setup-java@v1
- name: Set up environment
uses: olafurpg/setup-scala@v10
with:
java-version: 11
java-version: ${{ matrix.java }}

- name: Run tests
run: sbt -Dscalac.unused.enabled=true fmtCheck test
run: sbt ++${{ matrix.scala}} fmtCheck test
50 changes: 38 additions & 12 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
version = "3.2.1"
version = 3.7.14

runner.dialect = scala213
runner.dialect = scala213source3

maxColumn = 120
align.preset = most
continuationIndent.defnSite = 2
align = none
maxColumn = 120 # For wide displays.
assumeStandardLibraryStripMargin = true
align.tokens = [{code = "=>", owner = "Case"}, {code = "<-"}]
align.allowOverflow = true
align.arrowEnumeratorGenerator = true
align.openParenCallSite = false
align.openParenDefnSite = false
newlines.topLevelStatementBlankLines = [
{blanks {before = 1}}
{regex = "^Import|^Term.ApplyInfix"}
]
newlines.alwaysBeforeElseAfterCurlyIf = false
indentOperator.topLevelOnly = true
docstrings.style = SpaceAsterisk
lineEndings = preserve
docstrings.wrapMaxColumn = 80

includeCurlyBraceInSelectChains = false
danglingParentheses.preset = true
spaces {
inImportCurlyBraces = true
}
optIn.annotationNewlines = true
includeNoParensInSelectChains = false

rewrite.rules = [SortModifiers, PreferCurlyFors, SortImports, RedundantBraces]
rewrite.scala3.convertToNewSyntax = true
rewrite.imports.groups = [
[".*"],
["java\\..*", "javax\\..*", "scala\\..*"]
]

project.excludeFilters = ["/target/"]

lineEndings = preserve

rewrite.rules = [SortImports, RedundantBraces]
fileOverride {
"glob:**/*.sbt" {
runner.dialect = sbt1
rewrite.scala3.convertToNewSyntax = false
}
"glob:**/project/*.scala" {
rewrite.scala3.convertToNewSyntax = false
}
}
24 changes: 14 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import sbtwelcome._
import sbtwelcome.*

inThisBuild(
List(
organization := "com.github.reibitto",
homepage := Some(url("https://github.com/reibitto/sbt-test-shards")),
licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")),
developers := List(
homepage := Some(url("https://github.com/reibitto/sbt-test-shards")),
licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")),
developers := List(
Developer("reibitto", "reibitto", "[email protected]", url("https://reibitto.github.io"))
)
)
)

lazy val root = (project in file(".")).settings(
name := "sbt-test-shards",
name := "sbt-test-shards",
organization := "com.github.reibitto",
scalaVersion := "2.12.16",
sbtPlugin := true
scalaVersion := "2.12.18",
sbtPlugin := true,
libraryDependencies ++= Seq(
"org.scalameta" %% "munit" % "0.7.29" % Test,
"org.scalameta" %% "munit-scalacheck" % "0.7.29" % Test
)
)

addCommandAlias("fmt", "all root/scalafmtSbt root/scalafmtAll")
Expand All @@ -41,9 +45,9 @@ logo :=
|""".stripMargin

usefulTasks := Seq(
UsefulTask("a", "~compile", "Compile with file-watch enabled"),
UsefulTask("b", "fmt", "Run scalafmt on the entire project"),
UsefulTask("c", "publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project")
UsefulTask("~compile", "Compile with file-watch enabled"),
UsefulTask("fmt", "Run scalafmt on the entire project"),
UsefulTask("publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project")
)

logoColor := scala.Console.MAGENTA
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.7.1
sbt.version=1.9.7
8 changes: 5 additions & 3 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")
addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.10")
addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.2.2")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")

addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12")

addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.4.0")
77 changes: 57 additions & 20 deletions src/main/scala/sbttestshards/ShardingAlgorithm.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package sbttestshards

import sbttestshards.parsers.JUnitReportParser

import java.nio.charset.StandardCharsets
import java.nio.file.Path
import java.time.Duration
import scala.util.hashing.MurmurHash3

// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like
trait ShardingAlgorithm {
Expand All @@ -11,34 +16,60 @@ trait ShardingAlgorithm {

object ShardingAlgorithm {

/** Shards by the suite name. This is the most reasonable default as it requires no additional setup. */
/** Shards by the suite name. This is the most reasonable default as it
* requires no additional setup.
*/
final case object SuiteName extends ShardingAlgorithm {

override def shouldRun(specName: String, shardContext: ShardContext): Boolean =
// TODO: Test whether `hashCode` gets a good distribution. Otherwise implement a different hash algorithm.
specName.hashCode.abs % shardContext.testShardCount == shardContext.testShard
MurmurHash3
.bytesHash(specName.getBytes(StandardCharsets.UTF_8))
.abs % shardContext.testShardCount == shardContext.testShard
}

/** Will always mark the test to run on this shard. Useful for debugging or for fallback algorithms. */
/** Will always mark the test to run on this shard. Useful for debugging or
* for fallback algorithms.
*/
final case object Always extends ShardingAlgorithm {
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = true
}

/** Will never mark the test to run on this shard. Useful for debugging or for fallback algorithms. */
/** Will never mark the test to run on this shard. Useful for debugging or for
* fallback algorithms.
*/
final case object Never extends ShardingAlgorithm {
override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false
}

/** Attempts to balance the shards by execution time so that no one shard takes significantly longer to complete than
* another.
object Balance {

def fromJUnitReports(
reportDirectories: Seq[Path],
shardsInfo: ShardingInfo,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
): Balance =
ShardingAlgorithm.Balance(
JUnitReportParser.parseDirectoriesRecursively(reportDirectories).testReports.map { r =>
SpecInfo(r.name, Some(Duration.ofMillis(r.timeTaken.toLong)))
},
shardsInfo,
fallbackShardingAlgorithm
)
}

/** Attempts to balance the shards by execution time so that no one shard
* takes significantly longer to complete than another.
*/
final case class Balance(
tests: List[TestSuiteInfo],
bucketCount: Int,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
specs: Seq[SpecInfo],
shardsInfo: ShardingInfo,
fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName
) extends ShardingAlgorithm {

// TODO: Median might be better here?
private val averageTime: Option[Duration] = {
val allTimeTaken = tests.flatMap(_.timeTaken)
val allTimeTaken = specs.flatMap(_.timeTaken)

allTimeTaken.reduceOption(_.plus(_)).map { d =>
if (allTimeTaken.isEmpty) Duration.ZERO
else d.dividedBy(allTimeTaken.length)
Expand All @@ -48,14 +79,19 @@ object ShardingAlgorithm {
// TODO: This uses a naive greedy algorithm for partitioning into approximately equal subsets. While this problem
// is NP-complete, there's a lot of room for improvement with other algorithms. Dynamic programming should be
// possible here.
private def createBucketMap(testShardCount: Int) = {
def distributeEvenly: Map[TestSuiteInfoSimple, Int] = {
val durationOrdering: Ordering[Duration] = (a: Duration, b: Duration) => a.compareTo(b)

val allTests = tests
val allTests = specs
.map(t => TestSuiteInfoSimple(t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration.ZERO))))
.sortBy(_.timeTaken)(durationOrdering.reverse)

val buckets = Array.fill(testShardCount)(TestBucket(Nil, Duration.ZERO))
val buckets = (0 until shardsInfo.shardCount).map { shardIndex =>
TestBucket(
tests = Nil,
sum = shardsInfo.initialDurations.getOrElse(shardIndex, Duration.ZERO)
)
}.toArray

allTests.foreach { test =>
val minBucket = buckets.minBy(_.sum)
Expand All @@ -66,14 +102,14 @@ object ShardingAlgorithm {

buckets.zipWithIndex.flatMap { case (bucket, i) =>
bucket.tests.map { info =>
info.name -> i
info -> i
}
}.toMap
}

// `bucketCount` doesn't necessarily need to match `testShardCount`, but ideally it should be a multiple of it.
// TODO: Maybe print a warning if it's not a multiple of it.
private val bucketMap: Map[String, Int] = createBucketMap(bucketCount)
private val bucketMap: Map[String, Int] = distributeEvenly.map { case (k, v) =>
k.name -> v
}

def shouldRun(specName: String, shardContext: ShardContext): Boolean =
bucketMap.get(specName) match {
Expand All @@ -82,6 +118,7 @@ object ShardingAlgorithm {
}
}

private final case class TestSuiteInfoSimple(name: String, timeTaken: Duration)
private final case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration)
final case class TestSuiteInfoSimple(name: String, timeTaken: Duration)

final private case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration)
}
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/ShardingInfo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

final case class ShardingInfo(shardCount: Int, initialDurations: Map[Int, Duration] = Map.empty)
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/SpecBucketItem.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

final case class SpecBucketItem(name: String, timeTaken: Duration)
5 changes: 5 additions & 0 deletions src/main/scala/sbttestshards/SpecInfo.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sbttestshards

import java.time.Duration

final case class SpecInfo(name: String, timeTaken: Option[Duration])
5 changes: 0 additions & 5 deletions src/main/scala/sbttestshards/TestBucketItem.scala

This file was deleted.

15 changes: 8 additions & 7 deletions src/main/scala/sbttestshards/TestShardsPlugin.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package sbttestshards

import sbt.Keys.*
import sbt.*
import sbt.Keys.*

object TestShardsPlugin extends AutoPlugin {

object autoImport {
val testShard = settingKey[Int]("testShard")
val testShardCount = settingKey[Int]("testShardCount")
val testShard = settingKey[Int]("testShard")
val testShardCount = settingKey[Int]("testShardCount")
val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm")
val testShardDebug = settingKey[Boolean]("testShardDebug")
val testShardDebug = settingKey[Boolean]("testShardDebug")
}

import autoImport.*
Expand All @@ -22,10 +23,10 @@ object TestShardsPlugin extends AutoPlugin {

override lazy val projectSettings: Seq[Def.Setting[?]] =
Seq(
testShard := stringConfig("TEST_SHARD", "0").toInt,
testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt,
testShard := stringConfig("TEST_SHARD", "0").toInt,
testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt,
shardingAlgorithm := ShardingAlgorithm.SuiteName,
testShardDebug := false,
testShardDebug := false,
Test / testOptions += {
val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value)
Tests.Filter { specName =>
Expand Down
5 changes: 0 additions & 5 deletions src/main/scala/sbttestshards/TestSuiteInfo.scala

This file was deleted.

Loading
Loading