diff --git a/msgpack-core/src/main/java/org/msgpack/core/buffer/ChannelBufferInput.java b/msgpack-core/src/main/java/org/msgpack/core/buffer/ChannelBufferInput.java index f00cb0c30..e8d7c1de8 100644 --- a/msgpack-core/src/main/java/org/msgpack/core/buffer/ChannelBufferInput.java +++ b/msgpack-core/src/main/java/org/msgpack/core/buffer/ChannelBufferInput.java @@ -62,14 +62,12 @@ public MessageBuffer next() throws IOException { ByteBuffer b = buffer.sliceAsByteBuffer(); - while (b.remaining() > 0) { - int ret = channel.read(b); - if (ret == -1) { - break; - } + int ret = channel.read(b); + if (ret == -1) { + return null; } b.flip(); - return b.remaining() == 0 ? null : buffer.slice(0, b.limit()); + return buffer.slice(0, b.limit()); } @Override diff --git a/msgpack-core/src/test/scala/org/msgpack/core/buffer/MessageBufferInputTest.scala b/msgpack-core/src/test/scala/org/msgpack/core/buffer/MessageBufferInputTest.scala index 6b1c0da48..a7653797e 100644 --- a/msgpack-core/src/test/scala/org/msgpack/core/buffer/MessageBufferInputTest.scala +++ b/msgpack-core/src/test/scala/org/msgpack/core/buffer/MessageBufferInputTest.scala @@ -16,10 +16,13 @@ package org.msgpack.core.buffer import java.io._ +import java.net.{InetSocketAddress} import java.nio.ByteBuffer +import java.nio.channels.{ServerSocketChannel, SocketChannel} +import java.util.concurrent.{Callable, Executors, TimeUnit} import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import org.msgpack.core.{MessagePack, MessagePackSpec, MessageUnpacker} +import org.msgpack.core.{MessagePack, MessagePackSpec} import xerial.core.io.IOUtil._ import scala.util.Random @@ -201,5 +204,44 @@ class MessageBufferInputTest buf.reset(in1) readInt(buf) shouldBe 42 } + + "unpack without blocking" in { + val server = ServerSocketChannel.open.bind(new InetSocketAddress("localhost", 0)) + val executorService = Executors.newCachedThreadPool + + try { + executorService.execute(new Runnable { + override def run { + val server_ch = server.accept + val packer = MessagePack.newDefaultPacker(server_ch) + packer.packString("0123456789") + packer.flush + // Keep the connection open + while (!executorService.isShutdown) { + TimeUnit.SECONDS.sleep(1) + } + packer.close + } + }) + + val future = executorService.submit(new Callable[String] { + override def call: String = { + val conn_ch = SocketChannel.open(new InetSocketAddress("localhost", server.socket.getLocalPort)) + val unpacker = MessagePack.newDefaultUnpacker(conn_ch) + val s = unpacker.unpackString + unpacker.close + s + } + }) + + future.get(5, TimeUnit.SECONDS) shouldBe "0123456789" + } + finally { + executorService.shutdown + if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) { + executorService.shutdownNow + } + } + } } }