From 37a73d50f942f7a4dea9271050ec7deb179e4a48 Mon Sep 17 00:00:00 2001 From: reibitto Date: Thu, 16 Nov 2023 22:25:32 +0900 Subject: [PATCH 1/2] WIP --- .github/workflows/release.yml | 11 ++- .github/workflows/scala.yml | 16 +-- .scalafmt.conf | 50 +++++++--- build.sbt | 24 +++-- project/build.properties | 2 +- project/plugins.sbt | 8 +- .../sbttestshards/ShardingAlgorithm.scala | 51 ++++++---- .../scala/sbttestshards/ShardingInfo.scala | 5 + .../scala/sbttestshards/SpecBucketItem.scala | 5 + src/main/scala/sbttestshards/SpecInfo.scala | 5 + .../scala/sbttestshards/TestBucketItem.scala | 5 - .../sbttestshards/TestShardsPlugin.scala | 15 +-- .../scala/sbttestshards/TestSuiteInfo.scala | 5 - .../parsers/JUnitReportParser.scala | 97 +++++++++++++++++++ .../sbttestshards/ShardingAlgorithmSpec.scala | 53 ++++++++++ 15 files changed, 281 insertions(+), 71 deletions(-) create mode 100644 src/main/scala/sbttestshards/ShardingInfo.scala create mode 100644 src/main/scala/sbttestshards/SpecBucketItem.scala create mode 100644 src/main/scala/sbttestshards/SpecInfo.scala delete mode 100644 src/main/scala/sbttestshards/TestBucketItem.scala delete mode 100644 src/main/scala/sbttestshards/TestSuiteInfo.scala create mode 100644 src/main/scala/sbttestshards/parsers/JUnitReportParser.scala create mode 100644 src/test/scala/sbttestshards/ShardingAlgorithmSpec.scala diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 73550b2..ad28bcf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,14 +7,17 @@ jobs: publish: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v2.3.4 + - uses: actions/checkout@v3 with: fetch-depth: 0 - - uses: olafurpg/setup-scala@v13 - - uses: olafurpg/setup-gpg@v3 + - uses: actions/setup-java@v3 + with: + distribution: temurin + java-version: 8 + cache: sbt - run: sbt ci-release env: PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} PGP_SECRET: ${{ secrets.PGP_SECRET }} SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} - SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} + SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} \ No newline at end of file diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 06be443..2d51e2d 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -7,16 +7,20 @@ on: branches: [master, main] jobs: - build: + jvm: + strategy: + fail-fast: false + matrix: + scala: [2.12.18] + java: [adopt@1.11, adopt@1.8] runs-on: ubuntu-latest - steps: - uses: actions/checkout@v2 - - name: Set up JDK 11 - uses: actions/setup-java@v1 + - name: Set up environment + uses: olafurpg/setup-scala@v10 with: - java-version: 11 + java-version: ${{ matrix.java }} - name: Run tests - run: sbt -Dscalac.unused.enabled=true fmtCheck test + run: sbt ++${{ matrix.scala}} fmtCheck test \ No newline at end of file diff --git a/.scalafmt.conf b/.scalafmt.conf index f0b16ee..75d2787 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,18 +1,44 @@ -version = "3.2.1" +version = 3.7.14 -runner.dialect = scala213 +runner.dialect = scala213source3 -maxColumn = 120 -align.preset = most -continuationIndent.defnSite = 2 +align = none +maxColumn = 120 # For wide displays. assumeStandardLibraryStripMargin = true +align.tokens = [{code = "=>", owner = "Case"}, {code = "<-"}] +align.allowOverflow = true +align.arrowEnumeratorGenerator = true +align.openParenCallSite = false +align.openParenDefnSite = false +newlines.topLevelStatementBlankLines = [ + {blanks {before = 1}} + {regex = "^Import|^Term.ApplyInfix"} +] +newlines.alwaysBeforeElseAfterCurlyIf = false +indentOperator.topLevelOnly = true docstrings.style = SpaceAsterisk -lineEndings = preserve +docstrings.wrapMaxColumn = 80 + includeCurlyBraceInSelectChains = false -danglingParentheses.preset = true -spaces { - inImportCurlyBraces = true -} -optIn.annotationNewlines = true +includeNoParensInSelectChains = false + +rewrite.rules = [SortModifiers, PreferCurlyFors, SortImports, RedundantBraces] +rewrite.scala3.convertToNewSyntax = true +rewrite.imports.groups = [ + [".*"], + ["java\\..*", "javax\\..*", "scala\\..*"] +] + +project.excludeFilters = ["/target/"] + +lineEndings = preserve -rewrite.rules = [SortImports, RedundantBraces] +fileOverride { + "glob:**/*.sbt" { + runner.dialect = sbt1 + rewrite.scala3.convertToNewSyntax = false + } + "glob:**/project/*.scala" { + rewrite.scala3.convertToNewSyntax = false + } +} \ No newline at end of file diff --git a/build.sbt b/build.sbt index 4199c48..9c584f7 100644 --- a/build.sbt +++ b/build.sbt @@ -1,21 +1,25 @@ -import sbtwelcome._ +import sbtwelcome.* inThisBuild( List( organization := "com.github.reibitto", - homepage := Some(url("https://github.com/reibitto/sbt-test-shards")), - licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")), - developers := List( + homepage := Some(url("https://github.com/reibitto/sbt-test-shards")), + licenses := List("Apache-2.0" -> url("https://www.apache.org/licenses/LICENSE-2.0")), + developers := List( Developer("reibitto", "reibitto", "reibitto@users.noreply.github.com", url("https://reibitto.github.io")) ) ) ) lazy val root = (project in file(".")).settings( - name := "sbt-test-shards", + name := "sbt-test-shards", organization := "com.github.reibitto", - scalaVersion := "2.12.16", - sbtPlugin := true + scalaVersion := "2.12.18", + sbtPlugin := true, + libraryDependencies ++= Seq( + "org.scalameta" %% "munit" % "0.7.29" % Test, + "org.scalameta" %% "munit-scalacheck" % "0.7.29" % Test + ) ) addCommandAlias("fmt", "all root/scalafmtSbt root/scalafmtAll") @@ -41,9 +45,9 @@ logo := |""".stripMargin usefulTasks := Seq( - UsefulTask("a", "~compile", "Compile with file-watch enabled"), - UsefulTask("b", "fmt", "Run scalafmt on the entire project"), - UsefulTask("c", "publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project") + UsefulTask("~compile", "Compile with file-watch enabled"), + UsefulTask("fmt", "Run scalafmt on the entire project"), + UsefulTask("publishLocal", "Publish the sbt plugin locally so that you can consume it from a different project") ) logoColor := scala.Console.MAGENTA diff --git a/project/build.properties b/project/build.properties index 22af262..e8a1e24 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.7.1 +sbt.version=1.9.7 diff --git a/project/plugins.sbt b/project/plugins.sbt index 18e2828..1a63a4a 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,3 +1,5 @@ -addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6") -addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.10") -addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.2.2") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") + +addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.5.12") + +addSbtPlugin("com.github.reibitto" % "sbt-welcome" % "0.4.0") diff --git a/src/main/scala/sbttestshards/ShardingAlgorithm.scala b/src/main/scala/sbttestshards/ShardingAlgorithm.scala index 71f925f..72b4d69 100644 --- a/src/main/scala/sbttestshards/ShardingAlgorithm.scala +++ b/src/main/scala/sbttestshards/ShardingAlgorithm.scala @@ -11,34 +11,43 @@ trait ShardingAlgorithm { object ShardingAlgorithm { - /** Shards by suite the name. This is the most reasonable default as it requires no additional setup. */ + /** Shards by suite the name. This is the most reasonable default as it + * requires no additional setup. + */ final case object SuiteName extends ShardingAlgorithm { + override def shouldRun(specName: String, shardContext: ShardContext): Boolean = // TODO: Test whether `hashCode` gets a good distribution. Otherwise implement a different hash algorithm. specName.hashCode.abs % shardContext.testShardCount == shardContext.testShard } - /** Will always mark the test to run on this shard. Useful for debugging or for fallback algorithms. */ + /** Will always mark the test to run on this shard. Useful for debugging or + * for fallback algorithms. + */ final case object Always extends ShardingAlgorithm { override def shouldRun(specName: String, shardContext: ShardContext): Boolean = true } - /** Will never mark the test to run on this shard. Useful for debugging or for fallback algorithms. */ + /** Will never mark the test to run on this shard. Useful for debugging or for + * fallback algorithms. + */ final case object Never extends ShardingAlgorithm { override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false } - /** Attempts to balance the shards by execution time so that no one shard takes significantly longer to complete than - * another. + /** Attempts to balance the shards by execution time so that no one shard + * takes significantly longer to complete than another. */ final case class Balance( - tests: List[TestSuiteInfo], - bucketCount: Int, - fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName + specs: List[SpecInfo], + shardsInfo: ShardingInfo, + fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName ) extends ShardingAlgorithm { + // TODO: Median might be better here? private val averageTime: Option[Duration] = { - val allTimeTaken = tests.flatMap(_.timeTaken) + val allTimeTaken = specs.flatMap(_.timeTaken) + allTimeTaken.reduceOption(_.plus(_)).map { d => if (allTimeTaken.isEmpty) Duration.ZERO else d.dividedBy(allTimeTaken.length) @@ -48,14 +57,19 @@ object ShardingAlgorithm { // TODO: This uses a naive greedy algorithm for partitioning into approximately equal subsets. While this problem // is NP-complete, there's a lot of room for improvement with other algorithms. Dynamic programming should be // possible here. - private def createBucketMap(testShardCount: Int) = { + def distributeEvenly: Map[TestSuiteInfoSimple, Int] = { val durationOrdering: Ordering[Duration] = (a: Duration, b: Duration) => a.compareTo(b) - val allTests = tests + val allTests = specs .map(t => TestSuiteInfoSimple(t.name, t.timeTaken.getOrElse(averageTime.getOrElse(Duration.ZERO)))) .sortBy(_.timeTaken)(durationOrdering.reverse) - val buckets = Array.fill(testShardCount)(TestBucket(Nil, Duration.ZERO)) + val buckets = (0 until shardsInfo.shardCount).map { shardIndex => + TestBucket( + tests = Nil, + sum = shardsInfo.initialDurations.getOrElse(shardIndex, Duration.ZERO) + ) + }.toArray allTests.foreach { test => val minBucket = buckets.minBy(_.sum) @@ -66,14 +80,14 @@ object ShardingAlgorithm { buckets.zipWithIndex.flatMap { case (bucket, i) => bucket.tests.map { info => - info.name -> i + info -> i } }.toMap } - // `bucketCount` doesn't necessary need to match `testShardCount`, but ideally it should be a multiple of it. - // TODO: Maybe print a warning if it's not a multiple of it. - private val bucketMap: Map[String, Int] = createBucketMap(bucketCount) + private val bucketMap: Map[String, Int] = distributeEvenly.map { case (k, v) => + k.name -> v + } def shouldRun(specName: String, shardContext: ShardContext): Boolean = bucketMap.get(specName) match { @@ -82,6 +96,7 @@ object ShardingAlgorithm { } } - private final case class TestSuiteInfoSimple(name: String, timeTaken: Duration) - private final case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration) + final case class TestSuiteInfoSimple(name: String, timeTaken: Duration) + + final private case class TestBucket(var tests: List[TestSuiteInfoSimple], var sum: Duration) } diff --git a/src/main/scala/sbttestshards/ShardingInfo.scala b/src/main/scala/sbttestshards/ShardingInfo.scala new file mode 100644 index 0000000..04123a2 --- /dev/null +++ b/src/main/scala/sbttestshards/ShardingInfo.scala @@ -0,0 +1,5 @@ +package sbttestshards + +import java.time.Duration + +final case class ShardingInfo(shardCount: Int, initialDurations: Map[Int, Duration] = Map.empty) diff --git a/src/main/scala/sbttestshards/SpecBucketItem.scala b/src/main/scala/sbttestshards/SpecBucketItem.scala new file mode 100644 index 0000000..6857534 --- /dev/null +++ b/src/main/scala/sbttestshards/SpecBucketItem.scala @@ -0,0 +1,5 @@ +package sbttestshards + +import java.time.Duration + +final case class SpecBucketItem(name: String, timeTaken: Duration) diff --git a/src/main/scala/sbttestshards/SpecInfo.scala b/src/main/scala/sbttestshards/SpecInfo.scala new file mode 100644 index 0000000..e400830 --- /dev/null +++ b/src/main/scala/sbttestshards/SpecInfo.scala @@ -0,0 +1,5 @@ +package sbttestshards + +import java.time.Duration + +final case class SpecInfo(name: String, timeTaken: Option[Duration]) diff --git a/src/main/scala/sbttestshards/TestBucketItem.scala b/src/main/scala/sbttestshards/TestBucketItem.scala deleted file mode 100644 index 90a598c..0000000 --- a/src/main/scala/sbttestshards/TestBucketItem.scala +++ /dev/null @@ -1,5 +0,0 @@ -package sbttestshards - -import java.time.Duration - -final case class TestBucketItem(name: String, timeTaken: Duration) diff --git a/src/main/scala/sbttestshards/TestShardsPlugin.scala b/src/main/scala/sbttestshards/TestShardsPlugin.scala index 8ab68a1..1cf5514 100644 --- a/src/main/scala/sbttestshards/TestShardsPlugin.scala +++ b/src/main/scala/sbttestshards/TestShardsPlugin.scala @@ -1,14 +1,15 @@ package sbttestshards -import sbt.Keys.* import sbt.* +import sbt.Keys.* object TestShardsPlugin extends AutoPlugin { + object autoImport { - val testShard = settingKey[Int]("testShard") - val testShardCount = settingKey[Int]("testShardCount") + val testShard = settingKey[Int]("testShard") + val testShardCount = settingKey[Int]("testShardCount") val shardingAlgorithm = settingKey[ShardingAlgorithm]("shardingAlgorithm") - val testShardDebug = settingKey[Boolean]("testShardDebug") + val testShardDebug = settingKey[Boolean]("testShardDebug") } import autoImport.* @@ -22,10 +23,10 @@ object TestShardsPlugin extends AutoPlugin { override lazy val projectSettings: Seq[Def.Setting[?]] = Seq( - testShard := stringConfig("TEST_SHARD", "0").toInt, - testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt, + testShard := stringConfig("TEST_SHARD", "0").toInt, + testShardCount := stringConfig("TEST_SHARD_COUNT", "1").toInt, shardingAlgorithm := ShardingAlgorithm.SuiteName, - testShardDebug := false, + testShardDebug := false, Test / testOptions += { val shardContext = ShardContext(testShard.value, testShardCount.value, sLog.value) Tests.Filter { specName => diff --git a/src/main/scala/sbttestshards/TestSuiteInfo.scala b/src/main/scala/sbttestshards/TestSuiteInfo.scala deleted file mode 100644 index e9b77f3..0000000 --- a/src/main/scala/sbttestshards/TestSuiteInfo.scala +++ /dev/null @@ -1,5 +0,0 @@ -package sbttestshards - -import java.time.Duration - -final case class TestSuiteInfo(name: String, timeTaken: Option[Duration]) diff --git a/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala b/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala new file mode 100644 index 0000000..04d61dd --- /dev/null +++ b/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala @@ -0,0 +1,97 @@ +package sbttestshards.parsers + +import java.nio.file.Files +import java.nio.file.Path +import scala.jdk.CollectionConverters.* +import scala.xml.XML + +final case class FullTestReport(testReports: Seq[SpecTestReport]) { + def specCount: Int = testReports.length + + def testCount: Int = testReports.map(_.testCount).sum + + def allPassed: Boolean = testReports.forall(!_.hasFailures) + + def badTestCount: Int = testReports.map(_.badTestCount).sum + + def hasFailures: Boolean = testReports.exists(_.hasFailures) + + def ++(other: FullTestReport): FullTestReport = FullTestReport(testReports ++ other.testReports) +} + +final case class SpecTestReport( + name: String, + testCount: Int, + errorCount: Int, + failureCount: Int, + skipCount: Int, + timeTaken: Double, + failedTests: Seq[String] +) { + def badTestCount: Int = failureCount + errorCount + + def successCount: Int = testCount - errorCount - failureCount - skipCount + + def hasFailures: Boolean = failedTests.nonEmpty +} + +object JUnitReportParser { + + def listReports(reportDirectory: Path): Iterator[Path] = + Files.list(reportDirectory).iterator().asScala.filter(_.toString.endsWith(".xml")) + + def listReportsRecursively(reportDirectory: Path): Iterator[Path] = + Files.walk(reportDirectory).iterator().asScala.filter(_.toString.endsWith(".xml")) + + def parseDirectories(reportDirectories: Path*): FullTestReport = + reportDirectories.map(parseDirectory).reduceLeft(_ ++ _) + + def parseDirectory(reportDirectory: Path): FullTestReport = + FullTestReport( + listReports(reportDirectory).map { reportFile => + parseReport(reportFile) + }.toSeq + ) + + def parseDirectoryRecursively(reportDirectory: Path): FullTestReport = + FullTestReport( + listReportsRecursively(reportDirectory).map { reportFile => + parseReport(reportFile) + }.toSeq + ) + + def parseReport(reportFile: Path): SpecTestReport = { + val xml = XML.loadFile(reportFile.toFile) + + val specName = xml \@ "name" + val testCount = (xml \@ "tests").toInt + val errorCount = (xml \@ "errors").toInt + val failureCount = (xml \@ "failures").toInt + val skipCount = (xml \@ "skipped").toInt + val timeTaken = (xml \@ "time").toDouble + + val testcaseNodes = xml \ "testcase" + + val failedTests = testcaseNodes.map { node => + if ((node \ "failure").nonEmpty) { + val testName = (node \@ "name").trim + + Some(testName) + } else { + None + } + }.collect { case Some(testName) => + testName + } + + SpecTestReport( + name = specName, + testCount = testCount, + errorCount = errorCount, + failureCount = failureCount, + skipCount = skipCount, + timeTaken = timeTaken, + failedTests = failedTests + ) + } +} diff --git a/src/test/scala/sbttestshards/ShardingAlgorithmSpec.scala b/src/test/scala/sbttestshards/ShardingAlgorithmSpec.scala new file mode 100644 index 0000000..9d7a5a1 --- /dev/null +++ b/src/test/scala/sbttestshards/ShardingAlgorithmSpec.scala @@ -0,0 +1,53 @@ +package sbttestshards + +import munit.ScalaCheckSuite +import org.scalacheck.Arbitrary +import org.scalacheck.Gen +import org.scalacheck.Prop.forAll +import org.scalacheck.Test + +import java.time.Duration + +class ShardingAlgorithmSpec extends ScalaCheckSuite { + + implicit val specInfoArbitrary: Arbitrary[SpecInfo] = Arbitrary { + for { + specName <- Gen.resize(30, Gen.alphaStr) + timeTaken <- Gen.choose(0L, 9 * 60 * 1000) + } yield SpecInfo(specName, Some(Duration.ofMillis(timeTaken))) + } + + override def scalaCheckTestParameters: Test.Parameters = + super.scalaCheckTestParameters.withMinSuccessfulTests(10000) + + property( + "each shard should be balanced so that the difference between two never exceed the maximum single spec time" + ) { + forAll { (tests: List[SpecInfo]) => + val algo = ShardingAlgorithm.Balance(tests, ShardingInfo(3)) + + val maxSpecTime = tests.map(_.timeTaken.getOrElse(Duration.ZERO)).reduceOption { (a, b) => + if (a.compareTo(b) > 0) a else b + } + + val bucketMap = algo.distributeEvenly.toSeq.groupBy(_._2).map { case (k, v) => + k -> v.map(_._1.timeTaken).reduceOption(_.plus(_)).getOrElse(Duration.ZERO) + } + + bucketMap.values.toSeq + .sliding(2) + .map { + case Seq(_) => + true + + case Seq(a, b) => + val difference = a.minus(b).abs() + + difference.compareTo(maxSpecTime.getOrElse(Duration.ZERO)) <= 0 + } + .reduceOption(_ && _) + .getOrElse(true) + } + } + +} From bbf49004dffb4462c457d9414d1a1e3878d94f49 Mon Sep 17 00:00:00 2001 From: reibitto Date: Sat, 18 Nov 2023 21:39:40 +0900 Subject: [PATCH 2/2] Cleanup --- .../sbttestshards/ShardingAlgorithm.scala | 28 +++++++++++++++++-- .../parsers/JUnitReportParser.scala | 23 +++++++-------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/main/scala/sbttestshards/ShardingAlgorithm.scala b/src/main/scala/sbttestshards/ShardingAlgorithm.scala index 72b4d69..b547bef 100644 --- a/src/main/scala/sbttestshards/ShardingAlgorithm.scala +++ b/src/main/scala/sbttestshards/ShardingAlgorithm.scala @@ -1,6 +1,11 @@ package sbttestshards +import sbttestshards.parsers.JUnitReportParser + +import java.nio.charset.StandardCharsets +import java.nio.file.Path import java.time.Duration +import scala.util.hashing.MurmurHash3 // This trait is open so that users can implement a custom `ShardingAlgorithm` if they'd like trait ShardingAlgorithm { @@ -17,8 +22,9 @@ object ShardingAlgorithm { final case object SuiteName extends ShardingAlgorithm { override def shouldRun(specName: String, shardContext: ShardContext): Boolean = - // TODO: Test whether `hashCode` gets a good distribution. Otherwise implement a different hash algorithm. - specName.hashCode.abs % shardContext.testShardCount == shardContext.testShard + MurmurHash3 + .bytesHash(specName.getBytes(StandardCharsets.UTF_8)) + .abs % shardContext.testShardCount == shardContext.testShard } /** Will always mark the test to run on this shard. Useful for debugging or @@ -35,11 +41,27 @@ object ShardingAlgorithm { override def shouldRun(specName: String, shardContext: ShardContext): Boolean = false } + object Balance { + + def fromJUnitReports( + reportDirectories: Seq[Path], + shardsInfo: ShardingInfo, + fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName + ): Balance = + ShardingAlgorithm.Balance( + JUnitReportParser.parseDirectoriesRecursively(reportDirectories).testReports.map { r => + SpecInfo(r.name, Some(Duration.ofMillis(r.timeTaken.toLong))) + }, + shardsInfo, + fallbackShardingAlgorithm + ) + } + /** Attempts to balance the shards by execution time so that no one shard * takes significantly longer to complete than another. */ final case class Balance( - specs: List[SpecInfo], + specs: Seq[SpecInfo], shardsInfo: ShardingInfo, fallbackShardingAlgorithm: ShardingAlgorithm = ShardingAlgorithm.SuiteName ) extends ShardingAlgorithm { diff --git a/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala b/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala index 04d61dd..37b110b 100644 --- a/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala +++ b/src/main/scala/sbttestshards/parsers/JUnitReportParser.scala @@ -43,21 +43,22 @@ object JUnitReportParser { def listReportsRecursively(reportDirectory: Path): Iterator[Path] = Files.walk(reportDirectory).iterator().asScala.filter(_.toString.endsWith(".xml")) - def parseDirectories(reportDirectories: Path*): FullTestReport = - reportDirectories.map(parseDirectory).reduceLeft(_ ++ _) - - def parseDirectory(reportDirectory: Path): FullTestReport = + def parseDirectories(reportDirectories: Seq[Path]): FullTestReport = FullTestReport( - listReports(reportDirectory).map { reportFile => - parseReport(reportFile) - }.toSeq + reportDirectories.flatMap { dir => + listReports(dir).map { reportFile => + parseReport(reportFile) + } + } ) - def parseDirectoryRecursively(reportDirectory: Path): FullTestReport = + def parseDirectoriesRecursively(reportDirectories: Seq[Path]): FullTestReport = FullTestReport( - listReportsRecursively(reportDirectory).map { reportFile => - parseReport(reportFile) - }.toSeq + reportDirectories.flatMap { dir => + listReportsRecursively(dir).map { reportFile => + parseReport(reportFile) + } + } ) def parseReport(reportFile: Path): SpecTestReport = {