Skip to content

Commit

Permalink
feat: stdin handler
Browse files Browse the repository at this point in the history
  • Loading branch information
iseki0 committed Mar 29, 2024
1 parent fc49489 commit fd2cf6d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 21 deletions.
34 changes: 25 additions & 9 deletions src/main/java/space/iseki/cmdpipe/Cmd.java
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ class Builder {
private File workingDir;
private StreamProcessorImpl<InputStream, ?> stdoutProcessor;
private StreamProcessorImpl<InputStream, ?> stderrProcessor;
private StreamProcessorImpl<OutputStream, ?> stdinProcessor;
private ProcessBuilder[] pbs;
private boolean autoGrantExecutablePerm = false;

Expand Down Expand Up @@ -306,8 +307,12 @@ private static <T> T[] nonNullInArray(T[] arr) {
return autoGrantExecutableOnFailure(true);
}

private boolean configureRedirect(ProcessBuilder pb, Stdio stdio, boolean processorSet) {
boolean inherit = (inheritIO & 1 << stdio.i) > 0;
private boolean isInheritIO(Stdio stdio) {
return (inheritIO & 1 << stdio.i) > 0;
}

private void configureRedirect(ProcessBuilder pb, Stdio stdio, boolean processorSet) {
boolean inherit = isInheritIO(stdio);
var redirect = inherit ? ProcessBuilder.Redirect.INHERIT : switch (stdio) {
case STDERR, STDOUT -> processorSet ? ProcessBuilder.Redirect.PIPE : ProcessBuilder.Redirect.DISCARD;
case STDIN -> ProcessBuilder.Redirect.PIPE;
Expand All @@ -317,7 +322,6 @@ private boolean configureRedirect(ProcessBuilder pb, Stdio stdio, boolean proces
case STDOUT -> pb.redirectOutput(redirect);
case STDERR -> pb.redirectError(redirect);
}
return redirect == ProcessBuilder.Redirect.PIPE;
}

/**
Expand Down Expand Up @@ -476,6 +480,12 @@ private boolean configureRedirect(ProcessBuilder pb, Stdio stdio, boolean proces
return this;
}

public @NotNull Builder handleStdin(@NotNull StreamProcessor<OutputStream, ?> processor) {
inheritIO(Stdio.STDIN, false);
this.stdinProcessor = (StreamProcessorImpl<OutputStream, ?>) Objects.requireNonNull(processor);
return this;
}

/**
* Start the command.
*
Expand All @@ -491,29 +501,35 @@ private boolean configureRedirect(ProcessBuilder pb, Stdio stdio, boolean proces
CmdImpl cmd = null;
var stdoutProcessor = this.stdoutProcessor;
var stderrProcessor = this.stderrProcessor;
var stdinProcessor = this.stdinProcessor;
try {
var pbs = getPbs();
var lastPb = pbs[pbs.length - 1];
var firstPb = pbs[0];
var stdoutStart = configureRedirect(lastPb, Stdio.STDOUT, stdoutProcessor != null);
var stderrStart = configureRedirect(lastPb, Stdio.STDERR, stderrProcessor != null);
configureRedirect(firstPb, Stdio.STDIN, stdinProcessor != null);
configureRedirect(lastPb, Stdio.STDOUT, stdoutProcessor != null);
configureRedirect(lastPb, Stdio.STDERR, stderrProcessor != null);
for (ProcessBuilder pb : pbs) configureEnvAndDir(pb);
var processes = autoGrantExecutablePerm ? startRetryIfFailed(pbs) : start(pbs);
cmd = new CmdImpl(processes, executor);
var lastProcess = processes.get(processes.size() - 1);
if (stdoutStart) {
assert stdoutProcessor != null;
if (stdoutProcessor != null) {
cmd.startHandler(Stdio.STDOUT, stdoutProcessor, lastProcess.getInputStream());
}
if (stderrStart) {
assert stderrProcessor != null;
if (stderrProcessor != null) {
cmd.startHandler(Stdio.STDERR, stderrProcessor, lastProcess.getErrorStream());
}
if (stdinProcessor != null) {
cmd.startHandler(Stdio.STDIN, stdinProcessor, lastProcess.getOutputStream());
} else if (!isInheritIO(Stdio.STDIN)) {
processes.get(0).getOutputStream().close();
}
return cmd;
} catch (Throwable th) {
var spThrows = new RuntimeException("command start failed", th);
if (stdoutProcessor != null) stdoutProcessor.markFailed(spThrows);
if (stderrProcessor != null) stderrProcessor.markFailed(spThrows);
if (stdinProcessor != null) stdinProcessor.markFailed(spThrows);
// keep behavior consistent with ProcessBuilder.startPipeline()
if (cmd != null) {
cmd.stopAll(true);
Expand Down
50 changes: 38 additions & 12 deletions src/test/kotlin/space/iseki/cmdpipe/CmdTest.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package space.iseki.cmdpipe

import org.junit.jupiter.api.Assumptions
import org.junit.jupiter.api.*
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.assertTimeoutPreemptively
import java.io.IOException
import java.nio.charset.Charset
import java.time.Duration
Expand Down Expand Up @@ -80,35 +78,36 @@ class CmdTest {

@Test
fun testRunNodeTimeoutKill() {
val node = Cmd.Builder().cmdline("node").start()
val stdin = Cmd.output { Thread.sleep(8) }
val node = startSkipTestIfNotFound(Cmd.Builder().cmdline("node").handleStdin(stdin))
try {
val f = node.backgroundWaitTimeoutKill(100, TimeUnit.MILLISECONDS)
assertSame(f, node.backgroundWaitTimeoutKill(100, TimeUnit.MILLISECONDS))
val t = measureTimeMillis {
assertTimeoutPreemptively(Duration.ofSeconds(1)) {
assertFalse(f.get())
}
assertTimeout(Duration.ofSeconds(1)) {
try {
stdin.future().get()
} catch (ignored: InterruptedException) {
}
}
}
println(t)
assertTrue(t > 100)
} finally {
if (node.process.isAlive) node.stopAll(true)
}

}

@Test
fun testRunNodeKill() {
val stdout = Cmd.input {
it.stream.bufferedReader(Charset.defaultCharset()).readText()
}
val node = try {
Cmd.Builder().cmdline("node").handleStdout(stdout).start()
} catch (e: IOException) {
if (e.message!!.contains("error=2")) {
Assumptions.assumeTrue(false, "node not found")
}
throw e
}
val node = startSkipTestIfNotFound(Cmd.Builder().cmdline("node").handleStdout(stdout))
val p = node.process
assertTrue(p.isAlive)
node.stopAll(true)
Expand All @@ -129,4 +128,31 @@ class CmdTest {
}
}

@Test
fun testNodeInteractive() {
val stdin = Cmd.output { (_, _, stream) ->
stream.use { it.write("console.log(12345+54321);".encodeToByteArray()) }
}
val stdout = Cmd.input { (_, _, s) ->
s.readAllBytes().decodeToString()
}
val stderr = Cmd.input { (_, _, s) ->
s.readAllBytes().decodeToString()
}
val cmd = Cmd.Builder().cmdline("node").handleStdin(stdin).handleStdout(stdout).handleStderr(stderr)
val f = startSkipTestIfNotFound(cmd).backgroundWaitTimeoutKill(3, TimeUnit.SECONDS)
assertTrue(f.get(), "node not exit")
assertContains(stdout.future().get(), "66666")
assertTrue(stderr.future().get().isEmpty())
}


private fun startSkipTestIfNotFound(cmd: Cmd.Builder) = try {
cmd.start()
} catch (e: IOException) {
if (e.message!!.contains("error=2")) {
Assumptions.assumeTrue(false, "executable file not found")
}
throw e
}
}

0 comments on commit fd2cf6d

Please sign in to comment.