Skip to content

Commit

Permalink
Backport ProcessInput refactor (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
reibitto authored Feb 9, 2022
1 parent 407bda3 commit 3880459
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/overview/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Here's list of contents available:

- **[Basics](basics.md)** — Creating a description of a command and transforming its output
- **[Piping](piping.md)** — Creating a pipeline of commands
- **[Interactive Processes](interactive_processes.md)** — Communicating with an interactive process
- **[Other](other.md)** — Miscellaneous operations such as settings the working direction, inheriting I/O, etc.

## Installation
Expand Down
30 changes: 30 additions & 0 deletions docs/overview/interactive_processes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
id: overview_interactive_processes
title: "Interactive Processes"
---

Sometimes you want to interact with a process in a back-and-forth manner by sending requests to the process and receiving responses back. For example, interacting with a repl-like process like `node -i`, `python -i`, etc. or an ssh server.

Here is an example of communicating with an interactive Python shell:

```scala mdoc:invisible
import zio._
import zio.process._
import java.nio.charset.StandardCharsets
```

```scala mdoc:silent
for {
commandQueue <- Queue.unbounded[Chunk[Byte]]
process <- Command("python", "-qi").stdin(ProcessInput.fromQueue(commandQueue)).run
_ <- process.stdout.linesStream.foreach { response =>
ZIO.debug(s"Response from REPL: $response")
}.forkDaemon
_ <- commandQueue.offer(Chunk.fromArray("1+1\n".getBytes(StandardCharsets.UTF_8)))
_ <- commandQueue.offer(Chunk.fromArray("2**8\n".getBytes(StandardCharsets.UTF_8)))
_ <- commandQueue.offer(Chunk.fromArray("import random\nrandom.randint(1, 100)\n".getBytes(StandardCharsets.UTF_8)))
_ <- commandQueue.offer(Chunk.fromArray("exit()\n".getBytes(StandardCharsets.UTF_8)))
} yield ()
```

You would probably want to create a helper for the repeated code, but this just a minimal example to help get you started.
37 changes: 26 additions & 11 deletions src/main/scala/zio/process/Command.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
*/
package zio.process

import java.io.File
import java.io.{File, OutputStream}
import java.lang.ProcessBuilder.Redirect
import java.nio.charset.Charset

import zio._
import zio.blocking.Blocking
import zio.stream.{ZSink, ZStream}
Expand Down Expand Up @@ -55,7 +54,7 @@ sealed trait Command {
* Inherit standard input, standard output, and standard error.
*/
def inheritIO: Command =
stdin(ProcessInput.inherit).stdout(ProcessOutput.Inherit).stderr(ProcessOutput.Inherit)
stdin(ProcessInput.Inherit).stdout(ProcessOutput.Inherit).stderr(ProcessOutput.Inherit)

/**
* Runs the command returning the output as a list of lines (default encoding of UTF-8).
Expand Down Expand Up @@ -117,8 +116,9 @@ sealed trait Command {
}

c.stdin match {
case ProcessInput(None) => builder.redirectInput(Redirect.INHERIT)
case ProcessInput(Some(_)) => ()
case ProcessInput.Inherit => builder.redirectInput(Redirect.INHERIT)
case ProcessInput.Pipe => builder.redirectInput(Redirect.PIPE)
case ProcessInput.FromStream(_, _) => ()
}

c.stdout match {
Expand All @@ -142,12 +142,14 @@ sealed trait Command {
case CommandThrowable.IOError(e) => e
}
_ <- c.stdin match {
case ProcessInput(None) => ZIO.unit
case ProcessInput(Some(input)) =>
case ProcessInput.Inherit | ProcessInput.Pipe => ZIO.unit
case ProcessInput.FromStream(input, flushChunks) =>
for {
outputStream <- process.execute(_.getOutputStream)
sink = if (flushChunks) fromOutputStreamFlushChunksEagerly(outputStream)
else ZSink.fromOutputStream(outputStream)
_ <- input
.run(ZSink.fromOutputStream(outputStream))
.run(sink)
.ensuring(UIO(outputStream.close()))
.forkDaemon
} yield ()
Expand All @@ -158,11 +160,16 @@ sealed trait Command {
c.flatten match {
case chunk if chunk.length == 1 => chunk.head.run
case chunk =>
val flushChunksEagerly = chunk.head.stdin match {
case ProcessInput.FromStream(_, eager) => eager
case ProcessInput.Inherit | ProcessInput.Pipe => false
}

val stream = chunk.tail.init.foldLeft(chunk.head.stream) { case (s, command) =>
command.stdin(ProcessInput.fromStream(s)).stream
command.stdin(ProcessInput.fromStream(s, flushChunksEagerly)).stream
}

chunk.last.stdin(ProcessInput.fromStream(stream)).run
chunk.last.stdin(ProcessInput.fromStream(stream, flushChunksEagerly)).run
}
}

Expand Down Expand Up @@ -244,6 +251,14 @@ sealed trait Command {
def <<(input: String): Command =
stdin(ProcessInput.fromUTF8String(input))

private def fromOutputStreamFlushChunksEagerly(os: OutputStream): ZSink[Blocking, Throwable, Byte, Nothing, Unit] =
ZSink.foreachChunk { (chunk: Chunk[Byte]) =>
zio.blocking.effectBlockingInterrupt {
os.write(chunk.toArray)
os.flush()
}
}

}

object Command {
Expand All @@ -268,7 +283,7 @@ object Command {
NonEmptyChunk(processName, args: _*),
Map.empty,
Option.empty[File],
ProcessInput.inherit,
ProcessInput.Inherit,
ProcessOutput.Pipe,
ProcessOutput.Pipe,
redirectErrorStream = false
Expand Down
62 changes: 50 additions & 12 deletions src/main/scala/zio/process/ProcessInput.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,77 @@
*/
package zio.process

import java.io.ByteArrayInputStream
import java.nio.charset.{Charset, StandardCharsets}

import zio.Chunk
import zio.blocking.Blocking
import zio.stream.{Stream, ZStream}
import zio.stream.ZStream
import zio.{Chunk, Queue}

import java.io.{ByteArrayInputStream, File}
import java.nio.charset.{Charset, StandardCharsets}
import java.nio.file.Path

final case class ProcessInput(source: Option[ZStream[Blocking, CommandError, Byte]])
sealed trait ProcessInput

object ProcessInput {
val inherit: ProcessInput = ProcessInput(None)
final case class FromStream(stream: ZStream[Blocking, CommandError, Byte], flushChunksEagerly: Boolean)
extends ProcessInput

case object Inherit extends ProcessInput
case object Pipe extends ProcessInput

/**
* Returns a ProcessInput from an array of bytes.
*/
def fromByteArray(bytes: Array[Byte]): ProcessInput =
ProcessInput(Some(Stream.fromInputStream(new ByteArrayInputStream(bytes)).mapError(CommandError.IOError.apply)))
ProcessInput.FromStream(
ZStream.fromInputStream(new ByteArrayInputStream(bytes)).mapError(CommandError.IOError.apply),
flushChunksEagerly = false
)

/**
* Returns a ProcessInput from a file.
*/
def fromFile(file: File, chunkSize: Int = ZStream.DefaultChunkSize): ProcessInput =
ProcessInput.FromStream(
ZStream.fromFile(file.toPath, chunkSize).refineOrDie { case CommandThrowable.IOError(e) => e },
flushChunksEagerly = false
)

/**
* Returns a ProcessInput from a path to a file.
*/
def fromPath(path: Path, chunkSize: Int = ZStream.DefaultChunkSize): ProcessInput =
ProcessInput.FromStream(
ZStream.fromFile(path, chunkSize).refineOrDie { case CommandThrowable.IOError(e) => e },
flushChunksEagerly = false
)

/**
* Returns a ProcessInput from a queue of bytes to send to the process in a controlled manner.
*/
def fromQueue(queue: Queue[Chunk[Byte]]): ProcessInput =
ProcessInput.fromStream(ZStream.fromQueue(queue).flattenChunks, flushChunksEagerly = true)

/**
* Returns a ProcessInput from a stream of bytes.
*
* You may want to set `flushChunksEagerly` to true when doing back-and-forth communication with a process such as
* interacting with a repl (flushing the command to send so that you can receive a response immediately).
*/
def fromStream(stream: ZStream[Blocking, CommandError, Byte]): ProcessInput =
ProcessInput(Some(stream))
def fromStream(stream: ZStream[Blocking, CommandError, Byte], flushChunksEagerly: Boolean = false): ProcessInput =
ProcessInput.FromStream(stream, flushChunksEagerly)

/**
* Returns a ProcessInput from a String with the given charset.
*/
def fromString(text: String, charset: Charset): ProcessInput =
ProcessInput(Some(ZStream.fromChunks(Chunk.fromArray(text.getBytes(charset)))))
ProcessInput.FromStream(ZStream.fromChunks(Chunk.fromArray(text.getBytes(charset))), flushChunksEagerly = false)

/**
* Returns a ProcessInput from a UTF-8 String.
*/
def fromUTF8String(text: String): ProcessInput =
ProcessInput(Some(ZStream.fromChunks(Chunk.fromArray(text.getBytes(StandardCharsets.UTF_8)))))
ProcessInput.FromStream(
ZStream.fromChunks(Chunk.fromArray(text.getBytes(StandardCharsets.UTF_8))),
flushChunksEagerly = false
)
}
6 changes: 6 additions & 0 deletions src/test/bash/stdin-echo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash

while read line
do
echo "$line"
done
20 changes: 19 additions & 1 deletion src/test/scala/zio/process/CommandSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import zio.stream.ZTransducer
import zio.test.Assertion._
import zio.test._
import zio.test.environment.TestClock
import zio.{Chunk, ExitCode, ZIO}
import zio.{Chunk, ExitCode, Queue, ZIO}

import java.util.Optional

Expand Down Expand Up @@ -57,6 +57,11 @@ object CommandSpec extends ZIOProcessBaseSpec {

assertM(zio)(equalTo("piped in"))
},
testM("accept file stdin") {
for {
lines <- Command("cat").stdin(ProcessInput.fromFile(new File("src/test/bash/echo-repeat.sh"))).lines
} yield assertTrue(lines.head == "#!/bin/bash")
},
testM("support different encodings") {
val zio =
Command("cat")
Expand Down Expand Up @@ -141,6 +146,19 @@ object CommandSpec extends ZIOProcessBaseSpec {
exit <- Command("ls").workingDirectory(new File("/some/bad/path")).lines.run
} yield assert(exit)(fails(isSubtype[CommandError.WorkingDirectoryMissing](anything)))
},
testM("connect to a repl-like process and flush the chunks eagerly and get responses right away") {
for {
commandQueue <- Queue.unbounded[Chunk[Byte]]
process <- Command("./stdin-echo.sh")
.workingDirectory(new File("src/test/bash"))
.stdin(ProcessInput.fromQueue(commandQueue))
.run
_ <- commandQueue.offer(Chunk.fromArray("line1\nline2\n".getBytes(StandardCharsets.UTF_8)))
_ <- commandQueue.offer(Chunk.fromArray("line3\n".getBytes(StandardCharsets.UTF_8)))
lines <- process.stdout.linesStream.take(3).runCollect
_ <- process.kill
} yield assertTrue(lines == Chunk("line1", "line2", "line3"))
},
testM("kill only kills parent process") {
for {
process <- Command("./sample-parent.sh").workingDirectory(new File("src/test/bash/kill-test")).run
Expand Down
1 change: 1 addition & 0 deletions website/sidebars.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"overview/overview_index",
"overview/overview_basics",
"overview/overview_piping",
"overview/overview_interactive_processes",
"overview/overview_other"
]
},
Expand Down

0 comments on commit 3880459

Please sign in to comment.