From cb554b8ce1f99120508eab0499c8fd994f8c2af2 Mon Sep 17 00:00:00 2001 From: Jorge Vicente Cantero Date: Wed, 26 Jun 2019 13:34:25 +0200 Subject: [PATCH] Add first version of sailgun A scala-based Nailgun client. --- .gitignore | 36 ++ .scalafmt.conf | 1 + LICENSE.md | 202 ++++++++++++ NOTICE.md | 13 + build.sbt | 22 ++ project/BuildPlugin.scala | 187 +++++++++++ project/Dependencies.scala | 19 ++ project/build.properties | 1 + project/build.sbt | 10 + src/main/java/sailgun/Terminal.java | 17 + src/main/scala/sailgun/Client.scala | 28 ++ src/main/scala/sailgun/TcpClient.scala | 74 +++++ src/main/scala/sailgun/logging/Logger.scala | 12 + .../scala/sailgun/logging/SailgunLogger.scala | 15 + src/main/scala/sailgun/protocol/Action.scala | 13 + .../scala/sailgun/protocol/ChunkTypes.scala | 23 ++ .../scala/sailgun/protocol/Defaults.scala | 21 ++ .../scala/sailgun/protocol/Protocol.scala | 310 ++++++++++++++++++ src/main/scala/sailgun/protocol/Streams.scala | 14 + src/test/java/sailgun/utils/SailgunEcho.java | 41 +++ .../java/sailgun/utils/SailgunHeartbeat.java | 62 ++++ src/test/scala/sailgun/BaseSuite.scala | 204 ++++++++++++ src/test/scala/sailgun/SailgunBaseSuite.scala | 260 +++++++++++++++ src/test/scala/sailgun/SailgunSpec.scala | 70 ++++ .../sailgun/logging/RecordingLogger.scala | 79 +++++ .../scala/sailgun/logging/Slf4jAdapter.scala | 177 ++++++++++ src/test/scala/sailgun/utils/Diff.scala | 36 ++ .../scala/sailgun/utils/DiffAssertions.scala | 148 +++++++++ src/test/scala/sailgun/utils/ExitNail.scala | 22 ++ .../utils/SailgunThreadLocalInputStream.scala | 9 + 30 files changed, 2126 insertions(+) create mode 100644 .gitignore create mode 100644 .scalafmt.conf create mode 100644 LICENSE.md create mode 100644 NOTICE.md create mode 100644 build.sbt create mode 100644 project/BuildPlugin.scala create mode 100644 project/Dependencies.scala create mode 100644 project/build.properties create mode 100644 project/build.sbt create mode 100644 src/main/java/sailgun/Terminal.java create mode 100644 src/main/scala/sailgun/Client.scala create mode 100644 src/main/scala/sailgun/TcpClient.scala create mode 100644 src/main/scala/sailgun/logging/Logger.scala create mode 100644 src/main/scala/sailgun/logging/SailgunLogger.scala create mode 100644 src/main/scala/sailgun/protocol/Action.scala create mode 100644 src/main/scala/sailgun/protocol/ChunkTypes.scala create mode 100644 src/main/scala/sailgun/protocol/Defaults.scala create mode 100644 src/main/scala/sailgun/protocol/Protocol.scala create mode 100644 src/main/scala/sailgun/protocol/Streams.scala create mode 100644 src/test/java/sailgun/utils/SailgunEcho.java create mode 100644 src/test/java/sailgun/utils/SailgunHeartbeat.java create mode 100644 src/test/scala/sailgun/BaseSuite.scala create mode 100644 src/test/scala/sailgun/SailgunBaseSuite.scala create mode 100644 src/test/scala/sailgun/SailgunSpec.scala create mode 100644 src/test/scala/sailgun/logging/RecordingLogger.scala create mode 100644 src/test/scala/sailgun/logging/Slf4jAdapter.scala create mode 100644 src/test/scala/sailgun/utils/Diff.scala create mode 100644 src/test/scala/sailgun/utils/DiffAssertions.scala create mode 100644 src/test/scala/sailgun/utils/ExitNail.scala create mode 100644 src/test/scala/sailgun/utils/SailgunThreadLocalInputStream.scala diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e2dd65d --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +.idea/ +bin/.coursier +bin/.scalafmt* + +# Required because these are the proxies for the sourcedeps +.bridge/ +.zinc/ +.nailgun/ + +# Directory in which to install locally bloop binaries to test them +.devbloop/ + +# zinc uses this local cache for publishing stuff +.ivy2/ + +target/ +.bloop/ + +# Ensime's config and cache +.ensime +.ensime_cache/ +integrations/gradle-bloop/lib + +# The index where we store project mappings for local benchmarks +.local-benchmarks + +bloop-config/ +out/ +.DS_Store +website/build/ +website/build/bloop-gh-pages/ +node_modules/ +package-lock.json +.metals/ +*.lock +benchmark-bridge/corpus/ \ No newline at end of file diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..4ef5fd1 --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1 @@ +version = "2.0.0-RC4" diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..e97e18c --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Jorge Vicente Cantero + Copyright 2019 EPFL (École Polytechnique Federal de Lausanne) + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE.md b/NOTICE.md new file mode 100644 index 0000000..f6ceb5a --- /dev/null +++ b/NOTICE.md @@ -0,0 +1,13 @@ +# Sailgun + +Copyright 2019-2020 Jorge Vicente Cantero +Copyright 2019-2020 Scala Center (EPFL) + +## Nailgun + +[Repository website](https://github.com/scalacenter/nailgun/). Fork from +[facebook/nailgun](https://github.com/facebook/nailgun). + +Licensed under the [Apache 2.0 license](https://github.com/scalacenter/nailgun/blob/master/LICENSE.txt). + +It includes minor modifications to the core algorithms and Nailgun protocol to suit Bloop's needs. \ No newline at end of file diff --git a/build.sbt b/build.sbt new file mode 100644 index 0000000..a705671 --- /dev/null +++ b/build.sbt @@ -0,0 +1,22 @@ +import build.BuildKeys._ +import build.Dependencies + +lazy val sailgun = project + .in(file(".")) + .enablePlugins(GraalVMNativeImagePlugin) + .settings(testSuiteSettings) + .settings( + name := "sailgun", + fork in run in Compile := true, + libraryDependencies ++= Seq( + Dependencies.jna, + Dependencies.jnaPlatform, + Dependencies.slf4jApi + ), + graalVMNativeImageOptions ++= List( + "--no-fallback", + "-H:+ReportExceptionStackTraces", + "-H:Log=registerResource", + "-H:IncludeResources=com/sun/jna/darwin/libjnidispatch.jnilib" + ) + ) diff --git a/project/BuildPlugin.scala b/project/BuildPlugin.scala new file mode 100644 index 0000000..1860fd1 --- /dev/null +++ b/project/BuildPlugin.scala @@ -0,0 +1,187 @@ +package build + +import java.io.File + +import bintray.BintrayKeys +import ch.epfl.scala.sbt.release.Feedback +import com.typesafe.sbt.SbtPgp.{autoImport => Pgp} +import sbt.{AutoPlugin, Def, Keys, PluginTrigger, Plugins, State, Task, ThisBuild} +import sbt.io.IO +import sbt.io.syntax.fileToRichFile +import sbt.librarymanagement.syntax.stringToOrganization +import sbtdynver.GitDescribeOutput +import ch.epfl.scala.sbt.release.ReleaseEarlyPlugin.{autoImport => ReleaseEarlyKeys} + +object BuildPlugin extends AutoPlugin { + import sbt.plugins.JvmPlugin + import sbt.plugins.IvyPlugin + import com.typesafe.sbt.SbtPgp + import ch.epfl.scala.sbt.release.ReleaseEarlyPlugin + import com.lucidchart.sbt.scalafmt.ScalafmtCorePlugin + + override def trigger: PluginTrigger = allRequirements + override def requires: Plugins = + JvmPlugin && ScalafmtCorePlugin && ReleaseEarlyPlugin && SbtPgp && IvyPlugin + val autoImport = BuildKeys + + override def globalSettings: Seq[Def.Setting[_]] = + BuildImplementation.globalSettings + override def buildSettings: Seq[Def.Setting[_]] = + BuildImplementation.buildSettings + override def projectSettings: Seq[Def.Setting[_]] = + BuildImplementation.projectSettings +} + +object BuildKeys { + import sbt.{Reference, RootProject, ProjectRef, BuildRef, file} + + def inProject(ref: Reference)(ss: Seq[Def.Setting[_]]): Seq[Def.Setting[_]] = + sbt.inScope(sbt.ThisScope.in(project = ref))(ss) + + def inProjectRefs(refs: Seq[Reference])(ss: Def.Setting[_]*): Seq[Def.Setting[_]] = + refs.flatMap(inProject(_)(ss)) + + def inCompileAndTest(ss: Def.Setting[_]*): Seq[Def.Setting[_]] = + Seq(sbt.Compile, sbt.Test).flatMap(sbt.inConfig(_)(ss)) + + import sbt.Test + val testSuiteSettings: Seq[Def.Setting[_]] = List( + Keys.testFrameworks += new sbt.TestFramework("utest.runner.Framework"), + Keys.libraryDependencies ++= List( + Dependencies.monix % Test, + Dependencies.utest % Test, + Dependencies.pprint % Test, + Dependencies.nailgun % Test, + Dependencies.difflib % Test, + Dependencies.nailgunExamples % Test, + ) + ) + +} + +object BuildImplementation { + import sbt.{url, file} + import sbt.{Developer, Resolver, Watched, Compile, Test} + import sbtdynver.DynVerPlugin.{autoImport => DynVerKeys} + + def GitHub(org: String, project: String): java.net.URL = + url(s"https://github.com/$org/$project") + def GitHubDev(handle: String, fullName: String, email: String) = + Developer(handle, fullName, email, url(s"https://github.com/$handle")) + + final val globalSettings: Seq[Def.Setting[_]] = Seq( + Keys.cancelable := true, + Keys.testOptions in Test += sbt.Tests.Argument("-oD"), + Keys.publishArtifact in Test := false, + Pgp.pgpPublicRing := { + if (Keys.insideCI.value) file("/drone/.gnupg/pubring.asc") + else Pgp.pgpPublicRing.value + }, + Pgp.pgpSecretRing := { + if (Keys.insideCI.value) file("/drone/.gnupg/secring.asc") + else Pgp.pgpSecretRing.value + } + ) + + private final val ThisRepo = GitHub("jvican", "sailgun") + final val buildSettings: Seq[Def.Setting[_]] = Seq( + Keys.organization := "me.vican.jorge", + Keys.updateOptions := Keys.updateOptions.value.withCachedResolution(true), + Keys.scalaVersion := Dependencies.Scala212Version, + Keys.triggeredMessage := Watched.clearWhenTriggered, + Keys.resolvers := { + val oldResolvers = Keys.resolvers.value + val scalacenterResolver = Resolver.bintrayRepo("jvican", "releases") + (oldResolvers :+ scalacenterResolver).distinct + }, + ReleaseEarlyKeys.releaseEarlyWith := { + // Only tag releases go directly to Maven Central, the rest go to bintray! + val isOnlyTag = DynVerKeys.dynverGitDescribeOutput.value + .map(v => v.commitSuffix.isEmpty && v.dirtySuffix.value.isEmpty) + if (isOnlyTag.getOrElse(false)) ReleaseEarlyKeys.SonatypePublisher + else ReleaseEarlyKeys.BintrayPublisher + }, + BintrayKeys.bintrayOrganization := Some("jvican"), + Keys.startYear := Some(2019), + Keys.autoAPIMappings := true, + Keys.publishMavenStyle := true, + Keys.homepage := Some(ThisRepo), + Keys.licenses := Seq("Apache-2.0" -> url("http://www.apache.org/licenses/LICENSE-2.0")), + Keys.developers := List( + GitHubDev("jvican", "Jorge Vicente Cantero", "jorge@vican.me") + ) + ) + + import sbt.{CrossVersion, compilerPlugin} + final val projectSettings: Seq[Def.Setting[_]] = Seq( + BintrayKeys.bintrayPackage := "sailgun", + BintrayKeys.bintrayRepository := "releases", + // Add some metadata that is useful to see in every on-merge bintray release + BintrayKeys.bintrayPackageLabels := List("client", "nailgun", "server", "scala", "tooling"), + ReleaseEarlyKeys.releaseEarlyPublish := BuildDefaults.releaseEarlyPublish.value, + Keys.scalacOptions := reasonableCompileOptions, + // Legal requirement: license and notice files must be in the published jar + Keys.resources in Compile ++= BuildDefaults.getLicense.value, + Keys.publishArtifact in Test := false, + Keys.publishArtifact in (Compile, Keys.packageDoc) := { + val output = DynVerKeys.dynverGitDescribeOutput.value + val version = Keys.version.value + BuildDefaults.publishDocAndSourceArtifact(output, version) + }, + Keys.publishArtifact in (Compile, Keys.packageSrc) := { + val output = DynVerKeys.dynverGitDescribeOutput.value + val version = Keys.version.value + BuildDefaults.publishDocAndSourceArtifact(output, version) + }, + Keys.publishLocalConfiguration in Compile := + Keys.publishLocalConfiguration.value.withOverwrite(true) + ) + + final val reasonableCompileOptions = ( + "-deprecation" :: "-encoding" :: "UTF-8" :: "-feature" :: "-language:existentials" :: + "-language:higherKinds" :: "-language:implicitConversions" :: "-unchecked" :: "-Yno-adapted-args" :: + "-Ywarn-numeric-widen" :: "-Ywarn-value-discard" :: "-Xfuture" :: Nil + ) + + object BuildDefaults { + val releaseEarlyPublish: Def.Initialize[Task[Unit]] = Def.task { + val logger = Keys.streams.value.log + val name = Keys.name.value + // We force publishSigned for all of the modules, yes or yes. + if (ReleaseEarlyKeys.releaseEarlyWith.value == ReleaseEarlyKeys.SonatypePublisher) { + logger.info(Feedback.logReleaseSonatype(name)) + } else { + logger.info(Feedback.logReleaseBintray(name)) + } + + Pgp.PgpKeys.publishSigned.value + } + + // From sbt-sensible https://gitlab.com/fommil/sbt-sensible/issues/5, legal requirement + val getLicense: Def.Initialize[Task[Seq[File]]] = Def.task { + val orig = (Keys.resources in Compile).value + val base = Keys.baseDirectory.value + val root = (Keys.baseDirectory in ThisBuild).value + + def fileWithFallback(name: String): File = + if ((base / name).exists) base / name + else if ((root / name).exists) root / name + else throw new IllegalArgumentException(s"legal file $name must exist") + + Seq(fileWithFallback("LICENSE.md"), fileWithFallback("NOTICE.md")) + } + + /** + * This setting figures out whether the version is a snapshot or not and configures + * the source and doc artifacts that are published by the build. + * + * Snapshot is a term with no clear definition. In this code, a snapshot is a revision + * that is dirty, e.g. has time metadata in its representation. In those cases, the + * build will not publish doc and source artifacts by any of the publishing actions. + */ + def publishDocAndSourceArtifact(info: Option[GitDescribeOutput], version: String): Boolean = { + val isStable = info.map(_.dirtySuffix.value.isEmpty) + !isStable.exists(stable => !stable || version.endsWith("-SNAPSHOT")) + } + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala new file mode 100644 index 0000000..b4a5232 --- /dev/null +++ b/project/Dependencies.scala @@ -0,0 +1,19 @@ +package build + +object Dependencies { + import sbt.librarymanagement.syntax.stringToOrganization + val Scala212Version = "2.12.8" + val jnaVersion = "4.5.0" + val nailgunVersion = "ee3c4343" + val difflibVersion = "1.3.0" + + val monix = "io.monix" %% "monix" % "2.3.3" + val utest = "com.lihaoyi" %% "utest" % "0.6.6" + val pprint = "com.lihaoyi" %% "pprint" % "0.5.3" + val jna = "net.java.dev.jna" % "jna" % jnaVersion + val slf4jApi = "org.slf4j" % "slf4j-api" % "1.7.26" + val jnaPlatform = "net.java.dev.jna" % "jna-platform" % jnaVersion + val nailgun = "ch.epfl.scala" % "nailgun-server" % nailgunVersion + val nailgunExamples = "ch.epfl.scala" % "nailgun-examples" % nailgunVersion + val difflib = "com.googlecode.java-diff-utils" % "diffutils" % difflibVersion +} diff --git a/project/build.properties b/project/build.properties new file mode 100644 index 0000000..c0bab04 --- /dev/null +++ b/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.2.8 diff --git a/project/build.sbt b/project/build.sbt new file mode 100644 index 0000000..13b267a --- /dev/null +++ b/project/build.sbt @@ -0,0 +1,10 @@ +val `sailgun-build` = project + .in(file(".")) + .settings( + scalaVersion := "2.12.8", + addSbtPlugin("com.dwijnand" % "sbt-dynver" % "3.1.0"), + addSbtPlugin("com.lucidchart" % "sbt-scalafmt" % "1.14"), + addSbtPlugin("ch.epfl.scala" % "sbt-release-early" % "2.1.1"), + addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.1.0-M13-2"), + addSbtPlugin("com.typesafe.sbt" % "sbt-native-packager" % "1.3.22") + ) diff --git a/src/main/java/sailgun/Terminal.java b/src/main/java/sailgun/Terminal.java new file mode 100644 index 0000000..65bff46 --- /dev/null +++ b/src/main/java/sailgun/Terminal.java @@ -0,0 +1,17 @@ +package sailgun; + +import com.sun.jna.Native; +import com.sun.jna.Library; + +public class Terminal { + private static Libc libc; + public static int hasTerminalAttached(int fd) { + libc = (Libc)Native.loadLibrary("c", Libc.class); + return libc.isatty(fd); + } + + public static interface Libc extends Library { + // unistd.h + int isatty (int fd); + } +} diff --git a/src/main/scala/sailgun/Client.scala b/src/main/scala/sailgun/Client.scala new file mode 100644 index 0000000..02808dc --- /dev/null +++ b/src/main/scala/sailgun/Client.scala @@ -0,0 +1,28 @@ +package sailgun + +import sailgun.logging.Logger +import sailgun.protocol.Streams +import sailgun.protocol.Defaults + +import java.nio.file.Path +import java.util.concurrent.atomic.AtomicBoolean + +abstract class Client { + def run( + cmd: String, + args: Array[String], + cwd: Path, + env: Map[String, String], + streams: Streams, + logger: Logger, + stop: AtomicBoolean + ): Int + + def run( + cmd: String, + args: Array[String], + streams: Streams, + logger: Logger, + stop: AtomicBoolean + ): Int = run(cmd, args, Defaults.cwd, Defaults.env, streams, logger, stop) +} diff --git a/src/main/scala/sailgun/TcpClient.scala b/src/main/scala/sailgun/TcpClient.scala new file mode 100644 index 0000000..0315a66 --- /dev/null +++ b/src/main/scala/sailgun/TcpClient.scala @@ -0,0 +1,74 @@ +package sailgun + +import sailgun.logging.Logger +import sailgun.logging.SailgunLogger +import sailgun.protocol.Defaults +import sailgun.protocol.Protocol +import sailgun.protocol.Streams + +import java.net.Socket +import java.nio.file.Paths +import java.nio.file.Path +import java.io.PrintStream +import java.io.InputStream +import java.net.InetAddress +import java.util.concurrent.atomic.AtomicBoolean +import java.net.SocketException + +class TcpClient(addr: InetAddress, port: Int) extends Client { + def run( + cmd: String, + args: Array[String], + cwd: Path, + env: Map[String, String], + streams: Streams, + logger: Logger, + stop: AtomicBoolean + ): Int = { + val socket = new Socket(addr, port) + try { + val in = socket.getInputStream() + val out = socket.getOutputStream() + val protocol = new Protocol(streams, cwd, env, logger, stop) + protocol.sendCommand(cmd, args, out, in) + } finally { + try { + if (socket.isClosed()) () + else { + try socket.shutdownInput() + finally { + try socket.shutdownOutput() + finally socket.close() + } + } + } catch { + case _: SocketException => () + } + } + } +} + +object TcpClient { + def apply(host: String, port: Int): TcpClient = { + new TcpClient(InetAddress.getByName(host), port) + } + + def main(args: Array[String]): Unit = { + val client = TcpClient(Defaults.Host, Defaults.Port) + val streams = Streams(System.in, System.out, System.err) + val logger = new SailgunLogger("tcp-logger", System.out, isVerbose = false) + + val code = client.run( + "about", + new Array(0), + Defaults.cwd, + Defaults.env, + streams, + logger, + new AtomicBoolean(false) + ) + + logger.debug(s"Return code is $code") + System.exit(code) + } +} diff --git a/src/main/scala/sailgun/logging/Logger.scala b/src/main/scala/sailgun/logging/Logger.scala new file mode 100644 index 0000000..fd84452 --- /dev/null +++ b/src/main/scala/sailgun/logging/Logger.scala @@ -0,0 +1,12 @@ +package sailgun.logging + +abstract class Logger { + val name: String + val isVerbose: Boolean + + def debug(msg: String): Unit + def error(msg: String): Unit + def warn(msg: String): Unit + def info(msg: String): Unit + def trace(exception: Throwable): Unit +} diff --git a/src/main/scala/sailgun/logging/SailgunLogger.scala b/src/main/scala/sailgun/logging/SailgunLogger.scala new file mode 100644 index 0000000..d402773 --- /dev/null +++ b/src/main/scala/sailgun/logging/SailgunLogger.scala @@ -0,0 +1,15 @@ +package sailgun.logging + +import java.io.PrintStream + +class SailgunLogger( + override val name: String, + out: PrintStream, + override val isVerbose: Boolean +) extends Logger { + def debug(msg: String): Unit = out.println(s"debug: $msg") + def error(msg: String): Unit = out.println(s"error: $msg") + def warn(msg: String): Unit = out.println(s"warn: $msg") + def info(msg: String): Unit = out.println(s"$msg") + def trace(exception: Throwable): Unit = exception.printStackTrace(out) +} diff --git a/src/main/scala/sailgun/protocol/Action.scala b/src/main/scala/sailgun/protocol/Action.scala new file mode 100644 index 0000000..277fe15 --- /dev/null +++ b/src/main/scala/sailgun/protocol/Action.scala @@ -0,0 +1,13 @@ +package sailgun.protocol + +import java.io.OutputStream + +sealed trait Action +object Action { + final case object SendStdin extends Action + final case class Exit(code: Int) extends Action + final case class ExitForcefully(error: Throwable) extends Action + final case class Print(bytes: Array[Byte], out: OutputStream) extends Action { + override def toString: String = s"Print(${bytes.toString})" + } +} diff --git a/src/main/scala/sailgun/protocol/ChunkTypes.scala b/src/main/scala/sailgun/protocol/ChunkTypes.scala new file mode 100644 index 0000000..2318b25 --- /dev/null +++ b/src/main/scala/sailgun/protocol/ChunkTypes.scala @@ -0,0 +1,23 @@ +package sailgun.protocol + +import java.nio.charset.StandardCharsets + +object ChunkTypes { + sealed abstract class ChunkType(repr: String) { + lazy val toByteRepr: Byte = + repr.getBytes(StandardCharsets.US_ASCII).apply(0) + } + + final case object Stdin extends ChunkType("0") + final case object Stdout extends ChunkType("1") + final case object Stderr extends ChunkType("2") + final case object StdinEOF extends ChunkType(".") + final case object SendInput extends ChunkType("S") + final case object Heartbeat extends ChunkType("H") + final case object Environment extends ChunkType("E") + final case object Directory extends ChunkType("D") + final case object Command extends ChunkType("C") + final case object Argument extends ChunkType("A") + final case object LongArgument extends ChunkType("L") + final case object Exit extends ChunkType("X") +} diff --git a/src/main/scala/sailgun/protocol/Defaults.scala b/src/main/scala/sailgun/protocol/Defaults.scala new file mode 100644 index 0000000..3224778 --- /dev/null +++ b/src/main/scala/sailgun/protocol/Defaults.scala @@ -0,0 +1,21 @@ +package sailgun.protocol + +import java.nio.file.Paths + +object Defaults { + val Version = "0.9.3" + val Host = "127.0.0.1" + val Port = 8313 + + val env: Map[String, String] = { + import scala.collection.JavaConverters._ + System.getenv().asScala.toMap + } + + val cwd = Paths.get(System.getProperty("user.dir")) + + object Time { + val DefaultHeartbeatIntervalMillis = 500.toLong + val SendThreadWaitTerminationMillis = 5000.toLong + } +} diff --git a/src/main/scala/sailgun/protocol/Protocol.scala b/src/main/scala/sailgun/protocol/Protocol.scala new file mode 100644 index 0000000..d5f915d --- /dev/null +++ b/src/main/scala/sailgun/protocol/Protocol.scala @@ -0,0 +1,310 @@ +package sailgun.protocol + +import sailgun.Terminal +import sailgun.logging.Logger + +import java.net.Socket +import java.io.OutputStream +import java.io.PrintStream +import java.io.InputStream +import java.io.DataOutputStream +import java.io.DataInputStream +import java.io.EOFException +import java.io.InputStreamReader +import java.io.BufferedReader + +import java.nio.charset.StandardCharsets +import java.nio.file.Path +import java.nio.file.Paths +import java.nio.file.Files +import java.nio.ByteBuffer + +import java.util.concurrent.Semaphore +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Try +import scala.util.Failure +import scala.util.Success +import scala.util.control.NonFatal + +/** + * An implementation of the nailgun protocol in Scala. + * + * It follows http://www.martiansoftware.com/nailgun/protocol.html and has + * been slightly inspired in the C and Python clients. The implementation has + * been simplified more than these two and optimized for readability. + * + * The protocol is designed to be used by different instances of + * [[sailgun.Client]] implementing different communication mechanisms (e.g. + * TCP / Unix Domain sockets / Windows Named Pipes). + */ +class Protocol( + streams: Streams, + cwd: Path, + environment: Map[String, String], + logger: Logger, + stopFurtherProcessing: AtomicBoolean +) { + private val absoluteCwd = cwd.toAbsolutePath().toString + private val exitCode: AtomicInteger = new AtomicInteger(-1) + private val isRunning: AtomicBoolean = new AtomicBoolean(false) + private val anyThreadFailed: AtomicBoolean = new AtomicBoolean(false) + private val sendStdinSemaphore: Semaphore = new Semaphore(0) + private val waitTermination: Semaphore = new Semaphore(0) + + val NailgunFileSeparator = java.io.File.separator + val NailgunPathSeparator = java.io.File.pathSeparator + def allEnvironment: Map[String, String] = { + def interactive(fd: Int): String = + Integer.toString(Terminal.hasTerminalAttached(fd)) + def skipIfNative(f: => String) = + if (System.getProperty("java.vm.name") == "Substrate VM") "0" else f + environment ++ Map( + "NAILGUN_FILESEPARATOR" -> NailgunFileSeparator, + "NAILGUN_PATHSEPARATOR" -> NailgunPathSeparator, + "NAILGUN_TTY_0" -> skipIfNative(interactive(0)), + "NAILGUN_TTY_1" -> skipIfNative(interactive(1)), + "NAILGUN_TTY_2" -> skipIfNative(interactive(2)) + ) + } + + def sendCommand( + cmd: String, + cmdArgs: Array[String], + out0: OutputStream, + in0: InputStream + ): Int = { + isRunning.set(true) + val in = new DataInputStream(in0) + val out = new DataOutputStream(out0) + + val sendStdin = createStdinThread(out) + val scheduleHeartbeat = createHeartbeatAndShutdownThread(in, out) + // Start heartbeat thread before sending command as python and C clients do + scheduleHeartbeat.start() + + try { + // Send client command's environment to Nailgun server + logger.debug("Sending arguments to Nailgun server") + cmdArgs.foreach(sendChunk(ChunkTypes.Argument, _, out)) + logger.debug("Sending environment variables to Nailgun server") + allEnvironment.foreach( + kv => sendChunk(ChunkTypes.Environment, s"${kv._1}=${kv._2}", out) + ) + logger.debug("Sending current working directory to Nailgun server") + sendChunk(ChunkTypes.Directory, absoluteCwd, out) + logger.debug("Sending command to Nailgun server") + sendChunk(ChunkTypes.Command, cmd, out) + logger.debug("Finished sending command information to Nailgun server") + + // Start thread sending stdin right after sending command + logger.debug("Starting thread to read stdin...") + sendStdin.start() + + while (exitCode.get() == -1) { + val action = processChunkFromServer(in) + logger.debug(s"Received action $action from Nailgun server") + action match { + case Action.Exit(code) => + exitCode.compareAndSet(-1, code) + case Action.ExitForcefully(error) => + exitCode.compareAndSet(-1, 1) + printException(error) + case Action.Print(bytes, out) => out.write(bytes) + case Action.SendStdin => sendStdinSemaphore.release() + } + } + } catch { + case NonFatal(exception) => + exitCode.compareAndSet(-1, 1) + if (!stopFurtherProcessing.get()) { + printException(exception) + } + } finally { + // Always disable `isRunning` when client finishes the command execution + isRunning.compareAndSet(true, false) + // Release with max to guarantee all `acquire` return + waitTermination.release(Int.MaxValue) + // Release stdin semaphore if `acquire` was done by `sendStdin` thread + sendStdinSemaphore.release(Int.MaxValue) + } + + if (stopFurtherProcessing.get()) { + sendStdin.interrupt() + } + + logger.debug("Waiting for stdin thread to finish...") + sendStdin.join() + logger.debug("Waiting for heartbeat thread to finish...") + scheduleHeartbeat.join() + logger.debug("Returning exit code...") + exitCode.get() + } + + def sendChunk( + tpe: ChunkTypes.ChunkType, + msg: String, + out: DataOutputStream + ): Unit = { + val payload = msg.getBytes(StandardCharsets.UTF_8) + out.writeInt(payload.length) + out.writeByte(tpe.toByteRepr.toInt) + out.write(payload) + out.flush() + } + + def processChunkFromServer(in: DataInputStream): Action = { + def readPayload(length: Int, in: DataInputStream): Array[Byte] = { + var total: Int = 0 + val bytes = new Array[Byte](length) + while (total < length) { + val read = in.read(bytes, total, length - total) + if (read < 0) { + // Error before reaching goal of read bytes + throw new EOFException("Couldn't read bytes from server") + } else { + total += read + } + } + bytes + } + + val readAction = Try { + val bytesToRead = in.readInt() + val chunkType = in.readByte() + chunkType match { + case ChunkTypes.SendInput.toByteRepr => + Action.SendStdin + case ChunkTypes.Stdout.toByteRepr => + Action.Print(readPayload(bytesToRead, in), streams.out) + case ChunkTypes.Stderr.toByteRepr => + Action.Print(readPayload(bytesToRead, in), streams.err) + case ChunkTypes.Exit.toByteRepr => + val bytes = readPayload(bytesToRead, in) + val code = + Integer.parseInt(new String(bytes, StandardCharsets.US_ASCII)) + Action.Exit(code) + case _ => + val error = new RuntimeException(s"Unexpected chunk type: $chunkType") + Action.ExitForcefully(error) + } + } + + readAction match { + case Success(action) => action + case Failure(exception) => Action.ExitForcefully(exception) + } + } + + def createHeartbeatAndShutdownThread( + in: DataInputStream, + out: DataOutputStream + ): Thread = { + daemonThread { () => + var continue: Boolean = true + while (continue) { + val acquired = waitTermination.tryAcquire( + Defaults.Time.DefaultHeartbeatIntervalMillis, + TimeUnit.MILLISECONDS + ) + if (acquired) { + continue = false + } else { + swallowExceptionsIfServerFinished { + if (stopFurtherProcessing.get()) { + out.synchronized { + out.flush() + try in.close() + finally out.close() + } + } + out.synchronized { + sendChunk(ChunkTypes.Heartbeat, "", out) + } + } + } + } + } + } + + def createStdinThread(out: DataOutputStream): Thread = { + daemonThread { () => + val reader = new BufferedReader(new InputStreamReader(streams.in)) + def shouldStop = !isRunning.get() || stopFurtherProcessing.get() + try { + var continue: Boolean = true + while (continue) { + if (shouldStop) { + continue = false + } else { + // Don't start sending input until SendStdin action is received from server + sendStdinSemaphore.acquire() + if (shouldStop) { + continue = false + } else { + val line = reader.readLine() + if (shouldStop) { + continue = false + } else if (line.length() == 0) { + () // Ignore if read line is empty + } else { + swallowExceptionsIfServerFinished { + out.synchronized { + if (line == null) sendChunk(ChunkTypes.StdinEOF, "", out) + else sendChunk(ChunkTypes.Stdin, line, out) + } + } + } + } + } + } + } finally reader.close() + } + } + + /** + * Swallows any exception thrown by the closure [[f]] if client exits before + * the timeout of [[Protocol.Time.SendThreadWaitTerminationMillis]]. + * + * Ignoring exceptions in this scenario makes sense (exception could have + * been caught by server finishing connection with client concurrently). + */ + private def swallowExceptionsIfServerFinished(f: => Unit): Unit = { + try f + catch { + case NonFatal(exception) => + // Should always be false while client waits for exit code from server + val acquired = waitTermination.tryAcquire( + Defaults.Time.SendThreadWaitTerminationMillis, + TimeUnit.MILLISECONDS + ) + + // Ignore exception if in less than the wait the client exited + if (acquired) () + else throw exception + } + } + + private def printException(exception: Throwable): Unit = { + logger.error("Unexpected error forces client exit!") + logger.trace(exception) + } + + private def daemonThread(run0: () => Unit): Thread = { + val t = new Thread { + override def run(): Unit = { + try run0() + catch { + case NonFatal(exception) => + if (anyThreadFailed.compareAndSet(false, true)) { + printException(exception) + } + } + } + } + t.setDaemon(true) + t + } +} diff --git a/src/main/scala/sailgun/protocol/Streams.scala b/src/main/scala/sailgun/protocol/Streams.scala new file mode 100644 index 0000000..dcd05a3 --- /dev/null +++ b/src/main/scala/sailgun/protocol/Streams.scala @@ -0,0 +1,14 @@ +package sailgun.protocol + +import java.io.InputStream +import java.io.OutputStream + +/** + * An instance of user-defined streams where the protocol will forward any + * stdout, stdin or stderr coming from the client. + * + * Note that this is decoupled from the logger API, which is mostly used for + * tracing the protocol behaviour and reporting errors. The logger can be + * backed by some of these user-defined streams but it isn't a requirement. + */ +case class Streams(in: InputStream, out: OutputStream, err: OutputStream) diff --git a/src/test/java/sailgun/utils/SailgunEcho.java b/src/test/java/sailgun/utils/SailgunEcho.java new file mode 100644 index 0000000..cbf230b --- /dev/null +++ b/src/test/java/sailgun/utils/SailgunEcho.java @@ -0,0 +1,41 @@ +/* + + Copyright 2004-2012, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +package com.martiansoftware.nailgun.examples; + +/** + * Echos everything it reads from System.in to System.out. + * + * @author Marty Lamb + */ +public class SailgunEcho { + public static void main(String[] args) throws Exception { + byte[] b = new byte[1024]; + int bytesRead = System.in.read(b); + boolean exit = false; + while (!exit && bytesRead != -1) { + String msg = new String(b, 0, bytesRead); + if (msg.equals("exit")) { + exit = true; + } else { + System.out.write(b, 0, bytesRead); + bytesRead = System.in.read(b); + } + } + } +} diff --git a/src/test/java/sailgun/utils/SailgunHeartbeat.java b/src/test/java/sailgun/utils/SailgunHeartbeat.java new file mode 100644 index 0000000..aa43fbc --- /dev/null +++ b/src/test/java/sailgun/utils/SailgunHeartbeat.java @@ -0,0 +1,62 @@ +/* + + Copyright 2004-2012, Jim Purbrick. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +*/ + +package sailgun.utils; + +import com.martiansoftware.nailgun.NGContext; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Print H for each heartbeat received + */ +public class SailgunHeartbeat { + + public static void nailMain(final NGContext context) { + long runTimeout = Long.MAX_VALUE; + String[] args = context.getArgs(); + if (args.length > 0) { + // first argument is the number of milliseconds to run a command + // if omitted it will never interrupt by itself + try { + runTimeout = Long.parseUnsignedLong(args[0]); + } catch (Exception e) {} + } + + try { + Object lock = new Object(); + AtomicBoolean shutdown = new AtomicBoolean(false); + + context.addClientListener(reason -> { + synchronized (lock) { + shutdown.set(true); + lock.notifyAll(); + } + }); + + context.addHeartbeatListener(() -> context.out.print("H")); + + synchronized (lock) { + if (!shutdown.get()) { + lock.wait(runTimeout); + } + } + } catch (InterruptedException ignored) { + System.out.println("Error code is 42"); + } + } +} \ No newline at end of file diff --git a/src/test/scala/sailgun/BaseSuite.scala b/src/test/scala/sailgun/BaseSuite.scala new file mode 100644 index 0000000..1952920 --- /dev/null +++ b/src/test/scala/sailgun/BaseSuite.scala @@ -0,0 +1,204 @@ +package sailgun + +import sailgun.utils.{Diff, DiffAssertions} + +import scala.concurrent.Await +import scala.concurrent.Future +import scala.concurrent.duration.Duration +import scala.language.experimental.macros +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import utest.TestSuite +import utest.Tests +import utest.asserts.Asserts +import utest.framework.Formatter +import utest.framework.TestCallTree +import utest.framework.Tree +import utest.ufansi.Attrs +import utest.ufansi.Str +import utest.ufansi.Color + +import monix.eval.Task +import java.{util => ju} + +class BaseSuite extends TestSuite { + val pprint = _root_.pprint.PPrinter.BlackWhite + val OS = System.getProperty("os.name").toLowerCase(ju.Locale.ENGLISH) + val isWindows: Boolean = OS.contains("windows") + + def isAppveyor: Boolean = "True" == System.getenv("APPVEYOR") + def beforeAll(): Unit = () + def afterAll(): Unit = () + def intercept[T: ClassTag](exprs: Unit): T = macro Asserts.interceptProxy[T] + + def assertNotEmpty(string: String): Unit = { + if (string.isEmpty) { + fail( + s"expected non-empty string, obtained empty string." + ) + } + } + + def assertEmpty(string: String): Unit = { + if (!string.isEmpty) { + fail( + s"expected empty string, obtained: $string" + ) + } + } + + def assertContains(string: String, substring: String): Unit = { + assert(string.contains(substring)) + } + + def assertNotContains(string: String, substring: String): Unit = { + assert(!string.contains(substring)) + } + + def assert(exprs: Boolean*): Unit = macro Asserts.assertProxy + + def assertNotEquals[T](obtained: T, expected: T, hint: String = ""): Unit = { + if (obtained == expected) { + val hintMsg = if (hint.isEmpty) "" else s" (hint: $hint)" + assertNoDiff(obtained.toString, expected.toString, hint) + fail(s"obtained=<$obtained> == expected=<$expected>$hintMsg") + } + } + + def assertEquals[T](obtained: T, expected: T, hint: String = ""): Unit = { + if (obtained != expected) { + val hintMsg = if (hint.isEmpty) "" else s" (hint: $hint)" + assertNoDiff(obtained.toString, expected.toString, hint) + fail(s"obtained=<$obtained> != expected=<$expected>$hintMsg") + } + } + + def assertNoDiff( + obtained: String, + expected: String, + obtainedTitle: String, + expectedTitle: String + )(implicit filename: sourcecode.File, line: sourcecode.Line): Unit = { + colored { + DiffAssertions.assertNoDiffOrPrintExpected( + obtained, + expected, + obtainedTitle, + expectedTitle, + true + ) + () + } + } + + def assertNoDiff( + obtained: String, + expected: String, + title: String = "", + print: Boolean = true + )(implicit filename: sourcecode.File, line: sourcecode.Line): Unit = { + colored { + DiffAssertions.assertNoDiffOrPrintExpected(obtained, expected, title, title, print) + () + } + } + + def colored[T]( + thunk: => T + )(implicit filename: sourcecode.File, line: sourcecode.Line): T = { + try { + thunk + } catch { + case scala.util.control.NonFatal(e) => + val message = e.getMessage.linesIterator + .map { line => + if (line.startsWith("+")) Color.Green(line) + else if (line.startsWith("-")) Color.LightRed(line) + else Color.Reset(line) + } + .mkString("\n") + val location = s"failed assertion at ${filename.value}:${line.value}\n" + throw new DiffAssertions.TestFailedException(location + message) + } + } + + import monix.execution.CancelableFuture + import java.util.concurrent.TimeUnit + import scala.concurrent.duration.FiniteDuration + def waitForDuration[T](future: CancelableFuture[T], duration: FiniteDuration)( + ifError: => Unit + ): T = { + import java.util.concurrent.TimeoutException + try Await.result(future, duration) + catch { + case t: TimeoutException => ifError; throw t + } + } + + def waitInSeconds[T](future: CancelableFuture[T], seconds: Int)(ifError: => Unit): T = { + waitForDuration(future, FiniteDuration(seconds.toLong, TimeUnit.SECONDS))(ifError) + } + + def waitInMillis[T](future: CancelableFuture[T], ms: Int)(ifError: => Unit): T = { + waitForDuration(future, FiniteDuration(ms.toLong, TimeUnit.MILLISECONDS))(ifError) + } + + override def utestAfterAll(): Unit = afterAll() + override def utestFormatter: Formatter = new Formatter { + override def exceptionMsgColor: Attrs = Attrs.Empty + override def exceptionStackFrameHighlighter( + s: StackTraceElement + ): Boolean = { + s.getClassName.startsWith("bloop.") && + !(s.getClassName.startsWith("bloop.util") || + s.getClassName.startsWith("bloop.testing")) + } + + override def formatWrapWidth: Int = 3000 + override def formatException(x: Throwable, leftIndent: String): Str = + super.formatException(x, "") + } + + case class FlatTest(name: String, thunk: () => Unit) + private val myTests = IndexedSeq.newBuilder[FlatTest] + + def ignore(name: String, label: String = "IGNORED")(fun: => Any): Unit = { + myTests += FlatTest( + utest.ufansi.Color.LightRed(s"$label - $name").toString(), + () => () + ) + } + + def test(name: String)(fun: => Any): Unit = { + myTests += FlatTest(name, () => { fun; () }) + } + + implicit lazy val testScheduler = { + monix.execution.Scheduler.Implicits.global + } + + def testAsync(name: String, maxDuration: Duration = Duration("20s"))( + run: => Unit + ): Unit = { + test(name) { + Await.result(Task { run }.runAsync(testScheduler), maxDuration) + } + } + + def fail(msg: String, stackBump: Int = 0): Nothing = { + val ex = new DiffAssertions.TestFailedException(msg) + ex.setStackTrace(ex.getStackTrace.slice(1 + stackBump, 5 + stackBump)) + throw ex + } + + override def tests: Tests = { + val ts = myTests.result() + val names = Tree("", ts.map(x => Tree(x.name)): _*) + val thunks = new TestCallTree({ + this.beforeAll() + Right(ts.map(x => new TestCallTree(Left(x.thunk())))) + }) + Tests(names, thunks) + } +} diff --git a/src/test/scala/sailgun/SailgunBaseSuite.scala b/src/test/scala/sailgun/SailgunBaseSuite.scala new file mode 100644 index 0000000..9d84a1f --- /dev/null +++ b/src/test/scala/sailgun/SailgunBaseSuite.scala @@ -0,0 +1,260 @@ +package sailgun + +import sailgun.logging.Logger +import sailgun.logging.RecordingLogger +import sailgun.logging.Slf4jAdapter +import sailgun.protocol.Defaults +import sailgun.protocol.Streams +import sailgun.utils.ExitNail +import sailgun.utils.SailgunHeartbeat + +import java.io.PrintStream +import java.nio.file.{Files, Path, Paths} +import java.util.concurrent.{ExecutionException, TimeUnit} +import java.nio.charset.StandardCharsets +import java.io.PipedInputStream +import java.io.PipedOutputStream +import java.io.InputStream +import java.io.OutputStream +import java.io.ByteArrayOutputStream +import java.net.InetAddress +import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicBoolean + +import monix.eval.Task +import monix.execution.misc.NonFatal +import monix.execution.Scheduler + +import scala.concurrent.Await +import scala.concurrent.duration.FiniteDuration + +import com.martiansoftware.nailgun.NGListeningAddress +import com.martiansoftware.nailgun.NGConstants +import com.martiansoftware.nailgun.Alias +import com.martiansoftware.nailgun.examples.Echo +import com.martiansoftware.nailgun.examples.HelloWorld +import com.martiansoftware.nailgun.examples.SailgunEcho +import com.martiansoftware.nailgun.{ + SailgunThreadLocalInputStream, + NGServer, + ThreadLocalPrintStream +} + +class SailgunBaseSuite extends BaseSuite { + protected final val TestPort = 8313 + private final val nailgunPool = Scheduler.computation(parallelism = 2) + + def startServer[T](streams: Streams, logger: Logger)( + op: Client => T + ): Task[T] = { + /* + * This code tricks nailgun into thinking it has already set up the streams + * and wrapped them with their own thread local-based wrappers. We do this + * to avoid Nailgun running `System.setIn`, `System.setOut` and + * `System.setErr` which would affect all tests run in the suite and + * effectively hide input/outputs. + */ + val currentIn = System.in + val currentOut = System.out + val currentErr = System.err + + // Some dummy streams that we use initially + val serverIn = new PipedInputStream() + val clientOut = new PipedOutputStream(serverIn) + val clientIn = new PipedInputStream() + val serverOut = new PrintStream(new PipedOutputStream(clientIn)) + val serverErr = new PrintStream(new ByteArrayOutputStream()) + + val localIn = new SailgunThreadLocalInputStream(serverIn) + val localOut = new ThreadLocalPrintStream(serverOut) + val localErr = new ThreadLocalPrintStream(serverOut) + + localIn.init(serverIn) + localOut.init(serverOut) + localErr.init(serverErr) + + System.in.synchronized { + System.setIn(localIn) + System.setOut(localOut) + System.setErr(localErr) + } + + val addr = InetAddress.getLoopbackAddress + val serverIsStarted = scala.concurrent.Promise[Unit]() + val serverIsFinished = scala.concurrent.Promise[Unit]() + val serverLogic = Task { + try { + val server = + prepareTestServer(localIn, localOut, localErr, addr, TestPort, logger) + serverIsStarted.success(()) + server.run() + serverIsFinished.success(()) + } catch { + case monix.execution.misc.NonFatal(t) => + currentErr.println("Error when starting server") + t.printStackTrace(currentErr) + serverIsStarted.failure(t) + serverIsFinished.failure(t) + } finally { + serverOut.flush() + serverErr.flush() + } + } + + val client = new TcpClient(addr, TestPort) + def clientCancel(t: Option[Throwable]) = Task { + serverOut.flush() + serverErr.flush() + + t.foreach(t => logger.trace(t)) + + val code = client.run( + "exit", + new Array(0), + Defaults.cwd, + Defaults.env, + streams, + logger, + new AtomicBoolean(false) + ) + + // Exit on Windows can sometimes return non-successful code even if exit succeeded + if (isWindows) { + if (code != 0) { + logger.debug(s"The status code for exit in Windows was ${code}.") + } + } else { + assert(code == 0) + } + + System.in.synchronized { + System.setIn(currentIn) + System.setOut(currentOut) + System.setErr(currentErr) + } + + } + + val runClientLogic = Task(op(client)) + .doOnFinish(clientCancel(_)) + .doOnCancel(clientCancel(None)) + + val startTrigger = Task.fromFuture(serverIsStarted.future) + val endTrigger = Task.fromFuture(serverIsFinished.future) + val runClient = { + for { + _ <- startTrigger + value <- runClientLogic + _ <- endTrigger + } yield value + } + + Task + .zip2(serverLogic, runClient) + .map(t => t._2) + .timeout(FiniteDuration(5, TimeUnit.SECONDS)) + } + + def prepareTestServer( + in: InputStream, + out: PrintStream, + err: PrintStream, + addr: InetAddress, + port: Int, + logger: Logger + ): NGServer = { + val javaLogger = new Slf4jAdapter(logger) + val address = new NGListeningAddress(addr, port) + val poolSize = NGServer.DEFAULT_SESSIONPOOLSIZE + val heartbeatMs = NGConstants.HEARTBEAT_TIMEOUT_MILLIS.toInt + val server = + new NGServer(address, poolSize, heartbeatMs, in, out, err, javaLogger) + server.setAllowNailsByClassName(false) + val aliases = server.getAliasManager + aliases.addAlias( + new Alias( + "heartbeat", + "Run `Heartbeat` naigun server example.", + classOf[SailgunHeartbeat] + ) + ) + aliases.addAlias( + new Alias( + "echo", + "Run `Echo` naigun server example.", + classOf[SailgunEcho] + ) + ) + aliases.addAlias( + new Alias( + "hello-world", + "Run `HelloWorld` naigun server example.", + classOf[HelloWorld] + ) + ) + aliases.addAlias( + new Alias( + "exit", + "Run `exit` on the nail main defined in this class.", + classOf[ExitNail] + ) + ) + server + } + + /** + * Starts a Nailgun server, creates a sailgun client and executes operations + * with that client. The server is killed when the client exits. + * + * @param streams The user-defined streams. + * @param log The logger instance for the test run. + * @param op A function that will receive the instantiated Client. + * @return The result of executing `op` on the client. + */ + def withRunningServer[T]( + streams: Streams, + logger: Logger + )(op: Client => T): T = { + val f = startServer(streams, logger)(op).runAsync(nailgunPool) + try Await.result(f, FiniteDuration(5, TimeUnit.SECONDS)) + catch { + case e: ExecutionException => throw e.getCause() + case t: Throwable => throw t + } finally f.cancel() + } + + case class TestInputs( + streams: Streams, + logger: RecordingLogger, + stop: AtomicBoolean, + private val client: Client, + private val out: ByteArrayOutputStream + ) { + def run(cmd: String, args: Array[String]): Int = + client.run(cmd, args, Defaults.cwd, Defaults.env, streams, logger, stop) + + def generateResult: String = { + new String(out.toByteArray(), StandardCharsets.UTF_8) + } + } + + val oldErr = System.err + def testSailgun( + testName: String, + in: InputStream = System.in + )(op: TestInputs => Unit): Unit = { + val logger = new RecordingLogger() + test(testName) { + try { + val stop = new AtomicBoolean(false) + val out = new ByteArrayOutputStream() + val streams = Streams(in, out, out) + withRunningServer(streams, logger) { client => + op(TestInputs(streams, logger, stop, client, out)) + } + } catch { + case t: TimeoutException => logger.dump(oldErr); throw t + } + } + } +} diff --git a/src/test/scala/sailgun/SailgunSpec.scala b/src/test/scala/sailgun/SailgunSpec.scala new file mode 100644 index 0000000..3ce49c5 --- /dev/null +++ b/src/test/scala/sailgun/SailgunSpec.scala @@ -0,0 +1,70 @@ +package sailgun + +import sailgun.protocol.Streams +import sailgun.logging.RecordingLogger +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets +import java.io.PipedOutputStream +import java.io.PipedInputStream +import java.io.PrintStream +import java.util.concurrent.TimeUnit +import monix.execution.Cancelable +import scala.concurrent.duration.FiniteDuration + +object SailgunSpec extends SailgunBaseSuite { + testSailgun("hello world works") { inputs => + val code = inputs.run("hello-world", new Array(0)) + assert(code == 0) + assertNoDiff( + inputs.generateResult, + "Hello, world!" + ) + } + + testSailgun("heartbeat works") { inputs => + val args = Array("2000") + val code = inputs.run("heartbeat", args) + assert(code == 0) + // In 2000ms, we can receive 3 'H' + assertNoDiff( + inputs.generateResult, + "HHH" + ) + } + + val echoStdout = new PipedOutputStream() + val echoStdin = new PipedInputStream(echoStdout) + testSailgun("echo works (via stdin)", echoStdin) { inputs => + val ps = new PrintStream(echoStdout) + + testScheduler.scheduleOnce(FiniteDuration(500, TimeUnit.MILLISECONDS)) { + ps.println("Hello, world!") + Thread.sleep(10) + ps.println("I am echo") + Thread.sleep(10) + ps.println("exit") + } + + val code = inputs.run("echo", new Array(0)) + assert(code == 0) + assertNoDiff( + inputs.generateResult, + "Hello, world!I am echo" + ) + } + + val echoStdout2 = new PipedOutputStream() + val echoStdin2 = new PipedInputStream(echoStdout2) + testSailgun("cancellation of echo works", echoStdin2) { inputs => + testScheduler.scheduleOnce(FiniteDuration(300, TimeUnit.MILLISECONDS)) { + inputs.stop.set(true) + } + + // Ignore return code, all we care is that we return + inputs.run("echo", new Array(0)) + assertNoDiff( + inputs.generateResult, + "" + ) + } +} diff --git a/src/test/scala/sailgun/logging/RecordingLogger.scala b/src/test/scala/sailgun/logging/RecordingLogger.scala new file mode 100644 index 0000000..ff05fe5 --- /dev/null +++ b/src/test/scala/sailgun/logging/RecordingLogger.scala @@ -0,0 +1,79 @@ +package sailgun.logging + +import java.io.PrintStream +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters.asScalaIteratorConverter + +class RecordingLogger( + debug: Boolean = false, + debugOut: Option[PrintStream] = None +) extends Logger { + override val isVerbose: Boolean = true + override val name: String = "SailgunRecordingLogger" + private[this] val messages = new ConcurrentLinkedQueue[(String, String)] + + def clear(): Unit = messages.clear() + def debugs: List[String] = getMessagesAt(Some("debug")) + def infos: List[String] = getMessagesAt(Some("info")) + def warnings: List[String] = getMessagesAt(Some("warn")) + def errors: List[String] = getMessagesAt(Some("error")) + + def getMessagesAt(level: Option[String]): List[String] = + getMessages(level).map(_._2) + + def getMessages(): List[(String, String)] = getMessages(None) + private def getMessages(level: Option[String]): List[(String, String)] = { + val initialMsgs = messages.iterator.asScala + val msgs = level match { + case Some(level) => initialMsgs.filter(_._1 == level) + case None => initialMsgs + } + + msgs.map { + // Remove trailing '\r' so that we don't have to special case for Windows + case (category, msg) => (category, msg.stripSuffix("\r")) + }.toList + } + + def add(key: String, value: String): Unit = { + if (debug) { + debugOut match { + case Some(o) => o.println(s"[$key] $value") + case None => println(s"[$key] $value") + } + } + + messages.add((key, value)) + () + } + + override def debug(msg: String): Unit = add("debug", msg) + override def info(msg: String): Unit = add("info", msg) + override def error(msg: String): Unit = add("error", msg) + override def warn(msg: String): Unit = add("warn", msg) + private def trace(msg: String): Unit = add("trace", msg) + override def trace(ex: Throwable): Unit = { + ex.getStackTrace.foreach(ste => trace(ste.toString)) + Option(ex.getCause).foreach { cause => + trace("Caused by:") + trace(cause) + } + } + + def dump(out: PrintStream = System.out): Unit = { + out.println { + s"""Logger contains the following messages: + |${getMessages + .map(s => s"[${s._1}] ${s._2}") + .mkString("\n ", "\n ", "\n")} + """.stripMargin + } + } + + def render: String = { + getMessages() + .map { case (level, msg) => s"[${level}] ${msg}" } + .mkString(System.lineSeparator()) + } +} diff --git a/src/test/scala/sailgun/logging/Slf4jAdapter.scala b/src/test/scala/sailgun/logging/Slf4jAdapter.scala new file mode 100644 index 0000000..c17d7d2 --- /dev/null +++ b/src/test/scala/sailgun/logging/Slf4jAdapter.scala @@ -0,0 +1,177 @@ +package sailgun.logging + +import org.slf4j.{Marker, Logger => Slf4jLogger} + +/** + * Defines a slf4j-compliant logger wrapping Bloop logging utils. + * + * This slf4j interface is necessary to be compatible with third-party libraries + * like lsp4s. It only intends to cover the basic functionality and it does not + * support slf4j markers. + * + * @param logger A logger interface. + */ +final class Slf4jAdapter[L <: Logger](logger: L) extends Slf4jLogger { + def underlying: L = logger + override def getName: String = logger.name + override def debug(msg: String): Unit = logger.debug(msg) + override def debug(format: String, arg: scala.Any): Unit = + logger.debug(arg.toString) + + override def debug(msg: String, t: Throwable): Unit = logger.debug(msg) + override def debug(marker: Marker, msg: String): Unit = logger.debug(msg) + override def debug(marker: Marker, msg: String, t: Throwable): Unit = + logger.debug(msg) + override def debug(marker: Marker, format: String, arg: scala.Any): Unit = + logger.debug(arg.toString) + + override def debug(format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.debug(a.toString)) + override def debug(marker: Marker, format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.debug(a.toString)) + + override def debug(format: String, arg1: scala.Any, arg2: scala.Any): Unit = { + logger.debug(arg1.toString); logger.debug(arg2.toString) + } + + override def debug( + marker: Marker, + format: String, + arg1: scala.Any, + arg2: scala.Any + ): Unit = { + logger.debug(arg1.toString); logger.debug(arg2.toString) + } + + override def error(msg: String): Unit = logger.error(msg) + override def error(format: String, arg: scala.Any): Unit = + logger.error(arg.toString) + override def error(msg: String, t: Throwable): Unit = logger.error(msg) + override def error(marker: Marker, msg: String): Unit = logger.error(msg) + override def error(marker: Marker, format: String, arg: scala.Any): Unit = + logger.error(arg.toString) + override def error(marker: Marker, msg: String, t: Throwable): Unit = + logger.error(msg) + + override def error( + marker: Marker, + format: String, + arg1: scala.Any, + arg2: scala.Any + ): Unit = { + logger.error(arg1.toString); logger.error(arg2.toString) + } + + override def error(format: String, arg1: scala.Any, arg2: scala.Any): Unit = { + logger.error(arg1.toString); logger.error(arg2.toString) + } + + override def error(format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.error(a.toString)) + override def error(marker: Marker, format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.error(a.toString)) + + override def warn(msg: String): Unit = logger.warn(msg) + override def warn(format: String, arg: scala.Any): Unit = + logger.warn(arg.toString) + override def warn(msg: String, t: Throwable): Unit = logger.warn(msg) + override def warn(marker: Marker, msg: String): Unit = logger.warn(msg) + override def warn(marker: Marker, msg: String, t: Throwable): Unit = + logger.warn(msg) + + override def warn(marker: Marker, format: String, arg: scala.Any): Unit = + logger.warn(arg.toString) + + override def warn(format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.warn(a.toString)) + override def warn(marker: Marker, format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.warn(a.toString)) + + override def warn(format: String, arg1: scala.Any, arg2: scala.Any): Unit = { + logger.warn(arg1.toString); logger.warn(arg2.toString) + } + + override def warn( + marker: Marker, + format: String, + arg1: scala.Any, + arg2: scala.Any + ): Unit = { + logger.warn(arg1.toString); logger.warn(arg2.toString) + } + + override def trace(msg: String): Unit = logger.debug(msg) + override def trace(format: String, arg: scala.Any): Unit = + logger.debug(arg.toString) + override def trace(marker: Marker, msg: String): Unit = logger.debug(msg) + override def trace(marker: Marker, format: String, arg: scala.Any): Unit = + logger.debug(arg.toString) + override def trace(format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.debug(a.toString)) + override def trace(marker: Marker, format: String, argArray: AnyRef*): Unit = + argArray.foreach(a => logger.debug(a.toString)) + + override def trace(msg: String, t: Throwable): Unit = { + logger.debug(msg); logger.trace(t) + } + + override def trace(marker: Marker, msg: String, t: Throwable): Unit = { + logger.debug(msg); logger.trace(t) + } + + override def trace(format: String, arg1: scala.Any, arg2: scala.Any): Unit = { + logger.debug(arg1.toString); logger.debug(arg2.toString) + } + + override def trace( + marker: Marker, + format: String, + arg1: scala.Any, + arg2: scala.Any + ): Unit = { + logger.debug(arg1.toString); logger.debug(arg2.toString) + } + + override def isWarnEnabled: Boolean = true + override def isWarnEnabled(marker: Marker): Boolean = true + + override def isInfoEnabled: Boolean = true + override def isInfoEnabled(marker: Marker): Boolean = true + + override def isErrorEnabled: Boolean = true + override def isErrorEnabled(marker: Marker): Boolean = true + + override def isTraceEnabled: Boolean = logger.isVerbose + override def isTraceEnabled(marker: Marker): Boolean = logger.isVerbose + + override def isDebugEnabled: Boolean = logger.isVerbose + override def isDebugEnabled(marker: Marker): Boolean = logger.isVerbose + + override def info(msg: String): Unit = logger.info(msg) + override def info(format: String, arg: scala.Any): Unit = + logger.info(arg.toString) + + override def info(format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.info(a.toString)) + override def info(marker: Marker, format: String, arguments: AnyRef*): Unit = + arguments.foreach(a => logger.info(a.toString)) + + override def info(msg: String, t: Throwable): Unit = logger.info(msg) + override def info(marker: Marker, msg: String): Unit = logger.info(msg) + override def info(marker: Marker, format: String, arg: scala.Any): Unit = + logger.info(arg.toString) + override def info(marker: Marker, msg: String, t: Throwable): Unit = + logger.info(msg) + + override def info(format: String, arg1: scala.Any, arg2: scala.Any): Unit = { + logger.info(arg1.toString); logger.info(arg2.toString) + } + override def info( + marker: Marker, + format: String, + arg1: scala.Any, + arg2: scala.Any + ): Unit = { + logger.info(arg1.toString); logger.info(arg2.toString) + } +} diff --git a/src/test/scala/sailgun/utils/Diff.scala b/src/test/scala/sailgun/utils/Diff.scala new file mode 100644 index 0000000..ace9957 --- /dev/null +++ b/src/test/scala/sailgun/utils/Diff.scala @@ -0,0 +1,36 @@ +package sailgun.utils + +object Diff { + def unifiedDiff( + original: String, + revised: String, + obtained: String, + expected: String + ): String = + compareContents( + splitIntoLines(original), + splitIntoLines(revised), + obtained, + expected + ) + + private def splitIntoLines(string: String): Seq[String] = + string.trim.replace("\r\n", "\n").split("\n") + + private def compareContents( + original: Seq[String], + revised: Seq[String], + obtained: String, + expected: String + ): String = { + import scala.collection.JavaConverters._ + val diff = difflib.DiffUtils.diff(original.asJava, revised.asJava) + if (diff.getDeltas.isEmpty) "" + else { + difflib.DiffUtils + .generateUnifiedDiff(obtained, expected, original.asJava, diff, 1) + .asScala + .mkString("\n") + } + } +} diff --git a/src/test/scala/sailgun/utils/DiffAssertions.scala b/src/test/scala/sailgun/utils/DiffAssertions.scala new file mode 100644 index 0000000..ec6df7d --- /dev/null +++ b/src/test/scala/sailgun/utils/DiffAssertions.scala @@ -0,0 +1,148 @@ +package sailgun.utils + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal +import utest.ufansi.Color + +object DiffAssertions { + class TestFailedException(msg: String) extends Exception(msg) + def assertNoDiffOrPrintObtained( + obtained: String, + expected: String, + obtainedTitle: String, + expectedTitle: String + )(implicit source: sourcecode.Line): Unit = { + orPrintObtained( + () => { assertNoDiff(obtained, expected, obtainedTitle, expectedTitle); () }, + obtained + ) + } + + def assertNoDiffOrPrintExpected( + obtained: String, + expected: String, + obtainedTitle: String, + expectedTitle: String, + print: Boolean = true + )( + implicit source: sourcecode.Line + ): Boolean = { + try assertNoDiff(obtained, expected, obtainedTitle, expectedTitle) + catch { + case ex: Exception => + if (print) { + obtained.linesIterator.toList match { + case head +: tail => + val b = new StringBuilder() + b.++=(" \"\"\"|" + head).++=(System.lineSeparator()) + tail.foreach { line => + b.++=(" |") + .++=(line) + .++=(System.lineSeparator()) + } + b.++=(" |\"\"\".stripMargin").++=(System.lineSeparator()) + println(b.mkString) + case head +: Nil => + println(head) + case Nil => + println("obtained is empty") + } + } + throw ex + } + } + + def assertNoDiff( + obtained: String, + expected: String, + obtainedTitle: String, + expectedTitle: String + )(implicit source: sourcecode.Line): Boolean = colored { + if (obtained.isEmpty && !expected.isEmpty) fail("Obtained empty output!") + val result = Diff.unifiedDiff(obtained, expected, obtainedTitle, expectedTitle) + if (result.isEmpty) true + else { + throw new TestFailedException( + error2message( + obtained, + expected, + obtainedTitle, + expectedTitle + ) + ) + } + } + + private def error2message( + obtained: String, + expected: String, + obtainedTitle: String, + expectedTitle: String + ): String = { + def header[T](t: T): String = { + val line = s"=" * (t.toString.length + 3) + s"$line\n=> $t\n$line" + } + def stripTrailingWhitespace(str: String): String = + str.replaceAll(" \n", "∙\n") + val sb = new StringBuilder + if (obtained.length < 1000) { + sb.append( + s"""#${header("Obtained")} + #${stripTrailingWhitespace(obtained)} + # + #""".stripMargin('#') + ) + } + sb.append( + s"""#${header("Diff")} + #${stripTrailingWhitespace( + Diff.unifiedDiff(obtained, expected, obtainedTitle, expectedTitle) + )}""" + .stripMargin('#') + ) + sb.toString() + } + + def colored[T]( + thunk: => T + )(implicit filename: sourcecode.File, line: sourcecode.Line): T = { + try { + thunk + } catch { + case NonFatal(e) => + val message = e.getMessage.linesIterator + .map { line => + if (line.startsWith("+")) Color.Green(line) + else if (line.startsWith("-")) Color.LightRed(line) + else Color.Reset(line) + } + .mkString("\n") + val location = s"failed assertion at ${filename.value}:${line.value}\n" + throw new TestFailedException(location + message) + } + } + + def orPrintObtained(thunk: () => Unit, obtained: String): Unit = { + try thunk() + catch { + case ex: Exception => + obtained.linesIterator.toList match { + case head +: tail => + println(" \"\"\"|" + head) + tail.foreach(line => println(" |" + line)) + case head +: Nil => + println(head) + case Nil => + println("obtained is empty") + } + throw ex + } + } + + def fail(msg: String, stackBump: Int = 0): Nothing = { + val ex = new DiffAssertions.TestFailedException(msg) + ex.setStackTrace(ex.getStackTrace.slice(1 + stackBump, 2 + stackBump)) + throw ex + } +} diff --git a/src/test/scala/sailgun/utils/ExitNail.scala b/src/test/scala/sailgun/utils/ExitNail.scala new file mode 100644 index 0000000..8c55b2d --- /dev/null +++ b/src/test/scala/sailgun/utils/ExitNail.scala @@ -0,0 +1,22 @@ +package sailgun.utils + +import com.martiansoftware.nailgun.NGContext + +class ExitNail +object ExitNail { + def nailMain(ngContext: NGContext): Unit = { + val server = ngContext.getNGServer + import java.util.concurrent.ForkJoinPool + + ForkJoinPool + .commonPool() + .submit(new Runnable { + override def run(): Unit = { + server.shutdown(false) + } + }) + + () + } + +} diff --git a/src/test/scala/sailgun/utils/SailgunThreadLocalInputStream.scala b/src/test/scala/sailgun/utils/SailgunThreadLocalInputStream.scala new file mode 100644 index 0000000..84d47d7 --- /dev/null +++ b/src/test/scala/sailgun/utils/SailgunThreadLocalInputStream.scala @@ -0,0 +1,9 @@ +package com.martiansoftware.nailgun + +import java.io.InputStream + +final class SailgunThreadLocalInputStream(stream: InputStream) + extends ThreadLocalInputStream(stream) { + override def init(streamForCurrentThread: InputStream): Unit = + super.init(streamForCurrentThread) +}