From 63ed564c957b9df28932a1c9343e6f595d667ed8 Mon Sep 17 00:00:00 2001 From: reibitto Date: Tue, 12 Jul 2022 18:52:58 +0900 Subject: [PATCH] Initial commit --- .github/workflows/release.yml | 20 +++++ .github/workflows/scala.yml | 26 ++++++ .gitignore | 28 +++++++ .scalafmt.conf | 18 +++++ README.md | 4 +- build.sbt | 51 ++++++++++++ project/build.properties | 1 + project/plugins.sbt | 3 + .../scala/sbttestshards/ShardContext.scala | 5 ++ .../sbttestshards/ShardingAlgorithm.scala | 79 +++++++++++++++++++ .../scala/sbttestshards/TestBucketItem.scala | 5 ++ .../sbttestshards/TestShardsPlugin.scala | 28 +++++++ .../scala/sbttestshards/TestSuiteInfo.scala | 5 ++ 13 files changed, 272 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/release.yml create mode 100644 .github/workflows/scala.yml create mode 100644 .gitignore create mode 100644 .scalafmt.conf create mode 100644 build.sbt create mode 100644 project/build.properties create mode 100644 project/plugins.sbt create mode 100644 src/main/scala/sbttestshards/ShardContext.scala create mode 100644 src/main/scala/sbttestshards/ShardingAlgorithm.scala create mode 100644 src/main/scala/sbttestshards/TestBucketItem.scala create mode 100644 src/main/scala/sbttestshards/TestShardsPlugin.scala create mode 100644 src/main/scala/sbttestshards/TestSuiteInfo.scala diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..73550b2 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,20 @@ +name: Release +on: + push: + branches: [master, main] + tags: ["*"] +jobs: + publish: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v2.3.4 + with: + fetch-depth: 0 + - uses: olafurpg/setup-scala@v13 + - uses: olafurpg/setup-gpg@v3 + - run: sbt ci-release + env: + PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} + PGP_SECRET: ${{ secrets.PGP_SECRET }} + SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} + SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml new file mode 100644 index 0000000..5bf442a --- /dev/null +++ b/.github/workflows/scala.yml @@ -0,0 +1,26 @@ +name: Scala CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + jvm: + strategy: + fail-fast: false + matrix: + scala: [2.12.16] + java: [adopt@1.11, adopt@1.8] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Set up environment + uses: olafurpg/setup-scala@v10 + with: + java-version: ${{ matrix.java }} + + - name: Run tests + run: sbt ++${{ matrix.scala}} fmtCheck test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bc5a6bc --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +target +.idea +.idea_modules +.bloop +.bsp +.metals +/.classpath +/.project +/.settings +/RUNNING_PID +/out/ +*.iws +*.iml +/db +.eclipse +/lib/ +/logs/ +/modules +tmp/ +test-result +server.pid +*.eml +/dist/ +.cache +/reference +local.conf +/logs +publish.sbt diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..f0b16ee --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,18 @@ +version = "3.2.1" + +runner.dialect = scala213 + +maxColumn = 120 +align.preset = most +continuationIndent.defnSite = 2 +assumeStandardLibraryStripMargin = true +docstrings.style = SpaceAsterisk +lineEndings = preserve +includeCurlyBraceInSelectChains = false +danglingParentheses.preset = true +spaces { + inImportCurlyBraces = true +} +optIn.annotationNewlines = true + +rewrite.rules = [SortImports, RedundantBraces] diff --git a/README.md b/README.md index 0d8b36a..4f02eec 100644 --- a/README.md +++ b/README.md @@ -1 +1,3 @@ -# sbt-test-shards \ No newline at end of file +# SBT Test Shards + +*An SBT plugin for splitting tests across multiple shards to speed up tests.* diff --git a/build.sbt b/build.sbt new file mode 100644 index 0000000..4199c48 --- /dev/null +++ b/build.sbt @@ -0,0 +1,51 @@ +import sbtwelcome._ + +inThisBuild( + List( + organization := "com.github.reibitto", + homepage := Some(url("https://github.com/reibitto/sbt-test-shards")), + licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")), + developers := List( + Developer("reibitto", "reibitto", "reibitto@users.noreply.github.com", url("https://reibitto.github.io")) + ) + ) +) + +lazy val root = (project in file(".")).settings( + name := "sbt-test-shards", + organization := "com.github.reibitto", + scalaVersion := "2.12.16", + sbtPlugin := true +) + +addCommandAlias("fmt", "all root/scalafmtSbt root/scalafmtAll") +addCommandAlias("fmtCheck", "all root/scalafmtSbtCheck root/scalafmtCheckAll") + +logo := + s""" + | ______ _____ + | __________ /___ /_ + | __ ___/_ __ \\ __/ + | _(__ )_ /_/ / /_ + | /____/ /_.___/\\__/ + | _____ _____ ______ _________ + | __ /______________ /_ __________ /_______ _____________ /_______ + | _ __/ _ \\_ ___/ __/ __ ___/_ __ \\ __ `/_ ___/ __ /__ ___/ + | / /_ / __/(__ )/ /_ _(__ )_ / / / /_/ /_ / / /_/ / _(__ ) + | \\__/ \\___//____/ \\__/ /____/ /_/ /_/\\__,_/ /_/ \\__,_/ /____/ + | + |${version.value} + | + |${scala.Console.YELLOW}Scala ${scalaVersion.value}${scala.Console.RESET} + | + |""".stripMargin + +usefulTasks := Seq( + UsefulTask("a", "~compile", "Compile with file-watch enabled"), + UsefulTask("b", "fmt", "Run scalafmt on the entire project"), + UsefulTask("c", "publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project") +) + +logoColor := scala.Console.MAGENTA + +ThisBuild / organization := "com.github.reibitto" diff --git a/project/build.properties b/project/build.properties new file mode 100644 index 0000000..22af262 --- /dev/null +++ b/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.7.1 diff --git a/project/plugins.sbt b/project/plugins.sbt new file mode 100644 index 0000000..18e2828 --- /dev/null +++ b/project/plugins.sbt @@ -0,0 +1,3 @@ +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") +addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.10") +addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.2.2") diff --git a/src/main/scala/sbttestshards/ShardContext.scala b/src/main/scala/sbttestshards/ShardContext.scala new file mode 100644 index 0000000..e9af680 --- /dev/null +++ b/src/main/scala/sbttestshards/ShardContext.scala @@ -0,0 +1,5 @@ +package sbttestshards + +import sbt.Logger + +final case class ShardContext(testShard: Int, testShardCount: Int, logger: Logger) diff --git a/src/main/scala/sbttestshards/ShardingAlgorithm.scala b/src/main/scala/sbttestshards/ShardingAlgorithm.scala new file mode 100644 index 0000000..2914477 --- /dev/null +++ b/src/main/scala/sbttestshards/ShardingAlgorithm.scala @@ -0,0 +1,79 @@ +package sbttestshards + +import java.time.Duration + +// This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like +trait ShardingAlgorithm { + def isInShard(specName: String, shardContext: ShardContext): Boolean +} + +object ShardingAlgorithm { + final case object Always extends ShardingAlgorithm { + override def isInShard(specName: String, shardContext: ShardContext): Boolean = true + } + + final case object Never extends ShardingAlgorithm { + override def isInShard(specName: String, shardContext: ShardContext): Boolean = false + } + + final case object SuiteName extends ShardingAlgorithm { + override def isInShard(specName: String, shardContext: ShardContext): Boolean = { + val shouldRun = specName.hashCode % shardContext.testShardCount == shardContext.testShard + + println(s"${specName} will run? ${shouldRun}") + + shouldRun + } + } + + final case class Balance( + tests: List[TestSuiteInfo], + bucketCount: Int, + fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName + ) extends ShardingAlgorithm { + // TODO: Median might be better here? + private val averageTime: Option[Duration] = { + val allTimeTaken = tests.flatMap(_.timeTaken) + allTimeTaken.reduceOption(_.plus(_)).map { d => + if (d.isZero) Duration.ZERO + else d.dividedBy(allTimeTaken.length) + } + } + + private final case class TestSuiteInfoSimple(name: String, timeTaken: Duration) + private final case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration) + + private def createBucketMap(testShardCount: Int) = { + val durationOrdering: Ordering[Duration] = (a: Duration, b: Duration) => a.compareTo(b) + + val allTests = tests + .map(t => TestSuiteInfoSimple(t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration.ZERO)))) + .sortBy(_.timeTaken)(durationOrdering.reverse) + + val buckets = Array.fill(testShardCount)(TestBucket(Nil, Duration.ZERO)) + + allTests.foreach { test => + val minBucket = buckets.minBy(_.sum) + + minBucket.tests = test :: minBucket.tests + minBucket.sum = minBucket.sum.plus(test.timeTaken) + } + + buckets.zipWithIndex.flatMap { case (bucket, i) => + bucket.tests.map { info => + info.name -> i + } + }.toMap + } + + // `bucketCount` doesn't necessary need to match `testShardCount`, but ideally it should be a multiple of it. + // TODO: Maybe print a warning if it's not a multiple of it. + private val bucketMap: Map[String, Int] = createBucketMap(bucketCount) + + def isInShard(specName: String, shardContext: ShardContext): Boolean = + bucketMap.get(specName) match { + case Some(bucketIndex) => bucketIndex == shardContext.testShard + case None => fallbackShardingAlgorithm.isInShard(specName, shardContext) + } + } +} diff --git a/src/main/scala/sbttestshards/TestBucketItem.scala b/src/main/scala/sbttestshards/TestBucketItem.scala new file mode 100644 index 0000000..90a598c --- /dev/null +++ b/src/main/scala/sbttestshards/TestBucketItem.scala @@ -0,0 +1,5 @@ +package sbttestshards + +import java.time.Duration + +final case class TestBucketItem(name: String, timeTaken: Duration) diff --git a/src/main/scala/sbttestshards/TestShardsPlugin.scala b/src/main/scala/sbttestshards/TestShardsPlugin.scala new file mode 100644 index 0000000..f03ebde --- /dev/null +++ b/src/main/scala/sbttestshards/TestShardsPlugin.scala @@ -0,0 +1,28 @@ +package sbttestshards + +import sbt.Keys.* +import sbt.* + +object TestShardsPlugin extends AutoPlugin { + object autoImport { + val testShard = settingKey[Int]("testShard") + val testShardCount = settingKey[Int]("testShardCount") + val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm") + } + + import autoImport.* + + override def trigger = allRequirements + + override lazy val projectSettings: Seq[Def.Setting[?]] = + Seq( + testShard := 0, + testShardCount := 1, + shardingAlgorithm := ShardingAlgorithm.SuiteName, + Test / testOptions += { + val shardContext = ShardContext(testShardCount.value, testShard.value, sLog.value) + Tests.Filter(specName => shardingAlgorithm.value.isInShard(specName, shardContext)) + } + ) + +} diff --git a/src/main/scala/sbttestshards/TestSuiteInfo.scala b/src/main/scala/sbttestshards/TestSuiteInfo.scala new file mode 100644 index 0000000..e9b77f3 --- /dev/null +++ b/src/main/scala/sbttestshards/TestSuiteInfo.scala @@ -0,0 +1,5 @@ +package sbttestshards + +import java.time.Duration + +final case class TestSuiteInfo(name: String, timeTaken: Option[Duration])