diff --git a/src/main/scala/zio/process/Command.scala b/src/main/scala/zio/process/Command.scala index 6b0ca90a..6e0d2ad7 100644 --- a/src/main/scala/zio/process/Command.scala +++ b/src/main/scala/zio/process/Command.scala @@ -15,13 +15,12 @@ */ package zio.process -import java.io.File -import java.lang.ProcessBuilder.Redirect -import java.nio.charset.Charset - import zio._ import zio.stream.{ ZSink, ZStream } +import java.io.{ File, OutputStream } +import java.lang.ProcessBuilder.Redirect +import java.nio.charset.Charset import scala.jdk.CollectionConverters._ sealed trait Command { @@ -54,7 +53,7 @@ sealed trait Command { * Inherit standard input, standard output, and standard error. */ def inheritIO: Command = - stdin(ProcessInput.inherit).stdout(ProcessOutput.Inherit).stderr(ProcessOutput.Inherit) + stdin(ProcessInput.Inherit).stdout(ProcessOutput.Inherit).stderr(ProcessOutput.Inherit) /** * Runs the command returning the output as a list of lines (default encoding of UTF-8). @@ -116,8 +115,9 @@ sealed trait Command { } c.stdin match { - case ProcessInput(None) => builder.redirectInput(Redirect.INHERIT) - case ProcessInput(Some(_)) => () + case ProcessInput.Inherit => builder.redirectInput(Redirect.INHERIT) + case ProcessInput.Pipe => builder.redirectInput(Redirect.PIPE) + case ProcessInput.FromStream(_, _) => () } c.stdout match { @@ -141,12 +141,14 @@ sealed trait Command { case CommandThrowable.IOError(e) => e } _ <- c.stdin match { - case ProcessInput(None) => ZIO.unit - case ProcessInput(Some(input)) => + case ProcessInput.Inherit | ProcessInput.Pipe => ZIO.unit + case ProcessInput.FromStream(input, flushChunks) => for { outputStream <- process.execute(_.getOutputStream) + sink = if (flushChunks) fromOutputStreamFlushChunksEagerly(outputStream) + else ZSink.fromOutputStream(outputStream) _ <- input - .run(ZSink.fromOutputStream(outputStream)) + .run(sink) .ensuring(UIO(outputStream.close())) .forkDaemon } yield () @@ -157,11 +159,16 @@ sealed trait Command { c.flatten match { case chunk if chunk.length == 1 => chunk.head.run case chunk => + val flushChunksEagerly = chunk.head.stdin match { + case ProcessInput.FromStream(_, eager) => eager + case ProcessInput.Inherit | ProcessInput.Pipe => false + } + val stream = chunk.tail.init.foldLeft(chunk.head.stream) { case (s, command) => - command.stdin(ProcessInput.fromStream(s)).stream + command.stdin(ProcessInput.fromStream(s, flushChunksEagerly)).stream } - chunk.last.stdin(ProcessInput.fromStream(stream)).run + chunk.last.stdin(ProcessInput.fromStream(stream, flushChunksEagerly)).run } } @@ -243,6 +250,14 @@ sealed trait Command { def <<(input: String): Command = stdin(ProcessInput.fromUTF8String(input)) + private def fromOutputStreamFlushChunksEagerly(os: OutputStream): ZSink[Any, Throwable, Byte, Nothing, Unit] = + ZSink.foreachChunk { (chunk: Chunk[Byte]) => + ZIO.attemptBlockingInterrupt { + os.write(chunk.toArray) + os.flush() + } + } + } object Command { @@ -267,7 +282,7 @@ object Command { NonEmptyChunk(processName, args: _*), Map.empty, Option.empty[File], - ProcessInput.inherit, + ProcessInput.Inherit, ProcessOutput.Pipe, ProcessOutput.Pipe, redirectErrorStream = false diff --git a/src/main/scala/zio/process/ProcessInput.scala b/src/main/scala/zio/process/ProcessInput.scala index 65ee4669..bbc784ac 100644 --- a/src/main/scala/zio/process/ProcessInput.scala +++ b/src/main/scala/zio/process/ProcessInput.scala @@ -15,38 +15,76 @@ */ package zio.process -import java.io.ByteArrayInputStream -import java.nio.charset.{ Charset, StandardCharsets } +import zio.stream.ZStream +import zio.{ Chunk, Queue } -import zio.Chunk -import zio.stream.{ Stream, ZStream } +import java.io.{ ByteArrayInputStream, File } +import java.nio.charset.{ Charset, StandardCharsets } +import java.nio.file.Path -final case class ProcessInput(source: Option[ZStream[Any, CommandError, Byte]]) +sealed trait ProcessInput object ProcessInput { - val inherit: ProcessInput = ProcessInput(None) + final case class FromStream(stream: ZStream[Any, CommandError, Byte], flushChunksEagerly: Boolean) + extends ProcessInput + + case object Inherit extends ProcessInput + case object Pipe extends ProcessInput /** * Returns a ProcessInput from an array of bytes. */ def fromByteArray(bytes: Array[Byte]): ProcessInput = - ProcessInput(Some(Stream.fromInputStream(new ByteArrayInputStream(bytes)).mapError(CommandError.IOError.apply))) + ProcessInput.FromStream( + ZStream.fromInputStream(new ByteArrayInputStream(bytes)).mapError(CommandError.IOError.apply), + flushChunksEagerly = false + ) + + /** + * Returns a ProcessInput from a file. + */ + def fromFile(file: File, chunkSize: Int = ZStream.DefaultChunkSize): ProcessInput = + ProcessInput.FromStream( + ZStream.fromFile(file, chunkSize).refineOrDie { case CommandThrowable.IOError(e) => e }, + flushChunksEagerly = false + ) + + /** + * Returns a ProcessInput from a path to a file. + */ + def fromPath(path: Path, chunkSize: Int = ZStream.DefaultChunkSize): ProcessInput = + ProcessInput.FromStream( + ZStream.fromPath(path, chunkSize).refineOrDie { case CommandThrowable.IOError(e) => e }, + flushChunksEagerly = false + ) + + /** + * Returns a ProcessInput from a queue of bytes to send to the process in a controlled manner. + */ + def fromQueue(queue: Queue[Chunk[Byte]]): ProcessInput = + ProcessInput.fromStream(ZStream.fromQueue(queue).flattenChunks, flushChunksEagerly = true) /** * Returns a ProcessInput from a stream of bytes. + * + * You may want to set `flushChunksEagerly` to true when doing back-and-forth communication with a process such as + * interacting with a repl (flushing the command to send so that you can receive a response immediately). */ - def fromStream(stream: ZStream[Any, CommandError, Byte]): ProcessInput = - ProcessInput(Some(stream)) + def fromStream(stream: ZStream[Any, CommandError, Byte], flushChunksEagerly: Boolean = false): ProcessInput = + ProcessInput.FromStream(stream, flushChunksEagerly) /** * Returns a ProcessInput from a String with the given charset. */ def fromString(text: String, charset: Charset): ProcessInput = - ProcessInput(Some(ZStream.fromChunks(Chunk.fromArray(text.getBytes(charset))))) + ProcessInput.FromStream(ZStream.fromChunks(Chunk.fromArray(text.getBytes(charset))), flushChunksEagerly = false) /** * Returns a ProcessInput from a UTF-8 String. */ def fromUTF8String(text: String): ProcessInput = - ProcessInput(Some(ZStream.fromChunks(Chunk.fromArray(text.getBytes(StandardCharsets.UTF_8))))) + ProcessInput.FromStream( + ZStream.fromChunks(Chunk.fromArray(text.getBytes(StandardCharsets.UTF_8))), + flushChunksEagerly = false + ) } diff --git a/src/test/bash/stdin-echo.sh b/src/test/bash/stdin-echo.sh new file mode 100755 index 00000000..0de47dae --- /dev/null +++ b/src/test/bash/stdin-echo.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +while read line +do + echo "$line" +done \ No newline at end of file diff --git a/src/test/scala/zio/process/CommandSpec.scala b/src/test/scala/zio/process/CommandSpec.scala index 7ebf7937..642a429c 100644 --- a/src/test/scala/zio/process/CommandSpec.scala +++ b/src/test/scala/zio/process/CommandSpec.scala @@ -3,7 +3,7 @@ package zio.process import zio.stream.ZPipeline import zio.test.Assertion._ import zio.test._ -import zio.{ durationInt, Chunk, ExitCode, ZIO } +import zio.{ durationInt, Chunk, ExitCode, Queue, ZIO } import java.io.File import java.nio.charset.StandardCharsets @@ -43,7 +43,7 @@ object CommandSpec extends ZIOProcessBaseSpec { }, test("accept streaming stdin") { val stream = Command("echo", "-n", "a", "b", "c").stream - val zio = Command("cat").stdin(ProcessInput.fromStream(stream)).string + val zio = Command("cat").stdin(ProcessInput.fromStream(stream, flushChunksEagerly = false)).string assertM(zio)(equalTo("a b c")) }, @@ -52,6 +52,11 @@ object CommandSpec extends ZIOProcessBaseSpec { assertM(zio)(equalTo("piped in")) }, + test("accept file stdin") { + for { + lines <- Command("cat").stdin(ProcessInput.fromFile(new File("src/test/bash/echo-repeat.sh"))).lines + } yield assertTrue(lines.head == "#!/bin/bash") + }, test("support different encodings") { val zio = Command("cat") @@ -136,6 +141,19 @@ object CommandSpec extends ZIOProcessBaseSpec { exit <- Command("ls").workingDirectory(new File("/some/bad/path")).lines.exit } yield assert(exit)(fails(isSubtype[CommandError.WorkingDirectoryMissing](anything))) }, + test("connect to a repl-like process and flush the chunks eagerly and get responses right away") { + for { + commandQueue <- Queue.unbounded[Chunk[Byte]] + process <- Command("./stdin-echo.sh") + .workingDirectory(new File("src/test/bash")) + .stdin(ProcessInput.fromQueue(commandQueue)) + .run + _ <- commandQueue.offer(Chunk.fromArray("line1\nline2\n".getBytes(StandardCharsets.UTF_8))) + _ <- commandQueue.offer(Chunk.fromArray("line3\n".getBytes(StandardCharsets.UTF_8))) + lines <- process.stdout.linesStream.take(3).runCollect + _ <- process.kill + } yield assertTrue(lines == Chunk("line1", "line2", "line3")) + }, test("kill only kills parent process") { for { process <- Command("./sample-parent.sh").workingDirectory(new File("src/test/bash/kill-test")).run