Skip to content

Commit

Permalink
fix: handle unexpected pipe closures
Browse files Browse the repository at this point in the history
fix: handle unexpected pipe closures
  • Loading branch information
vyfor authored May 10, 2024
2 parents 202bce0 + 25ff2a0 commit 062bb0e
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

```gradle
dependencies {
implementation("io.github.vyfor:kpresence:0.6.1")
implementation("io.github.vyfor:kpresence:0.6.2")
}
```

Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ plugins {
}

group = "io.github.vyfor"
version = "0.6.1"
version = "0.6.2"

repositories {
mavenCentral()
Expand Down
37 changes: 32 additions & 5 deletions src/commonMain/kotlin/io/github/vyfor/kpresence/RichClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlin.concurrent.Volatile

/**
* Manages client connections and activity updates for Discord presence.
Expand All @@ -28,6 +29,7 @@ class RichClient(
private var signal = Mutex(true)
private var lastActivity: Activity? = null

@Volatile
var connectionState = ConnectionState.DISCONNECTED
private set
var onReady: (RichClient.() -> Unit)? = null
Expand Down Expand Up @@ -143,7 +145,7 @@ class RichClient(
return this
}

connection.write(2, null)
write(2, null)
connectionState = ConnectionState.DISCONNECTED
connection.close()
lastActivity = null
Expand Down Expand Up @@ -186,17 +188,17 @@ class RichClient(
debug(packet)
}

connection.write(1, packet)
write(1, packet)
}

private fun handshake() {
connection.write(0, "{\"v\": 1,\"client_id\":\"$clientId\"}")
write(0, "{\"v\": 1,\"client_id\":\"$clientId\"}")
}

private fun listen(): Job {
return coroutineScope.launch {
while (isActive && connectionState != ConnectionState.DISCONNECTED) {
val response = connection.read() ?: continue
val response = read() ?: if (connectionState == ConnectionState.DISCONNECTED) break else continue
logger?.apply {
trace("Received response:")
trace("Message(opcode: ${response.opcode}, data: ${response.data.decodeToString()})")
Expand Down Expand Up @@ -225,13 +227,38 @@ class RichClient(
lastActivity = null
logger?.warn("The connection was forcibly closed")
onDisconnect?.invoke(this@RichClient)
break
}
break
}
}
}
}
}

private fun read(): Message? {
return try {
connection.read()
} catch (e: ConnectionClosedException) {
connectionState = ConnectionState.DISCONNECTED
logger?.warn("The connection was forcibly closed: ${e.message?.trimEnd()}. Client will be disconnected")
connection.close()
lastActivity = null
onDisconnect?.invoke(this@RichClient)
null
}
}

private fun write(opcode: Int, data: String?) {
try {
connection.write(opcode, data)
} catch (e: ConnectionClosedException) {
connectionState = ConnectionState.DISCONNECTED
logger?.warn("The connection was forcibly closed: ${e.message?.trimEnd()}. Client will be disconnected")
connection.close()
lastActivity = null
onDisconnect?.invoke(this@RichClient)
}
}
}

enum class ConnectionState {
Expand Down
6 changes: 2 additions & 4 deletions src/commonTest/kotlin/io/github/vyfor/kpresence/ClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class ClientTest {
continue
}

if (client.connectionState == ConnectionState.DISCONNECTED) break

if (input == "clear") {
client.clear()
continue
Expand All @@ -43,9 +45,5 @@ class ClientTest {
state = "KPresence"
}
}

if (client.connectionState != ConnectionState.DISCONNECTED) {
client.shutdown()
}
}
}
51 changes: 44 additions & 7 deletions src/jvmMain/kotlin/io/github/vyfor/kpresence/ipc/Connection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package io.github.vyfor.kpresence.ipc
import io.github.vyfor.kpresence.exception.*
import io.github.vyfor.kpresence.utils.putInt
import io.github.vyfor.kpresence.utils.reverseBytes
import java.io.IOException
import java.lang.System.getenv
import java.net.SocketException
import java.net.UnixDomainSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.AsynchronousFileChannel
import java.nio.channels.ClosedChannelException
import java.nio.channels.SocketChannel
import java.nio.file.InvalidPathException
import java.nio.file.StandardOpenOption
Expand Down Expand Up @@ -62,19 +65,24 @@ actual class Connection {
throw PipeNotFoundException()
}

override fun read(): Message {
override fun read(): Message? {
pipe?.let { stream ->
try {
val opcode = stream.readInt(0).reverseBytes()
val length = stream.readInt(4).reverseBytes()
val buffer = ByteBuffer.allocate(length)

stream.read(buffer, 8).get()
return Message(
opcode,
buffer.array()
)
} catch (e: AsynchronousCloseException) {
return null
} catch (e: ClosedChannelException) {
throw ConnectionClosedException(e.message.orEmpty())
} catch (e: Exception) {
if (e.cause?.message == "The pipe has been ended") throw ConnectionClosedException(e.message.orEmpty())
throw PipeReadException(e.message.orEmpty())
}
} ?: throw NotConnectedException()
Expand All @@ -93,7 +101,12 @@ actual class Connection {
}

stream.write(ByteBuffer.wrap(buffer), 0).get()
} catch (e: AsynchronousCloseException) {
return
} catch (e: ClosedChannelException) {
throw ConnectionClosedException(e.message.orEmpty())
} catch (e: Exception) {
if (e.cause?.message == "The pipe is being closed") throw ConnectionClosedException(e.message.orEmpty())
throw PipeWriteException(e.message.orEmpty())
}
} ?: throw NotConnectedException()
Expand Down Expand Up @@ -150,19 +163,31 @@ actual class Connection {
throw PipeNotFoundException()
}

override fun read(): Message {
override fun read(): Message? {
pipe?.let { stream ->
try {
val opcode = stream.readInt().reverseBytes()
val length = stream.readInt().reverseBytes()
val buffer = ByteBuffer.allocate(length)

stream.read(buffer)
val bytesRead = stream.read(buffer)
if (bytesRead == 0) throw ConnectionClosedException("The pipe has been closed")

return Message(
opcode,
buffer.array()
)
} catch (e: AsynchronousCloseException) {
return null
} catch (e: ConnectionClosedException) {
throw e
} catch (e: ClosedChannelException) {
throw ConnectionClosedException(e.message.orEmpty())
} catch (e: SocketException) {
if (e.message == "Connection reset") throw ConnectionClosedException(e.message.orEmpty())
throw PipeReadException(e.message.orEmpty())
} catch (e: Exception) {
if (e.message == "Broken pipe") throw ConnectionClosedException(e.message.orEmpty())
throw PipeReadException(e.message.orEmpty())
}
} ?: throw NotConnectedException()
Expand All @@ -179,9 +204,19 @@ actual class Connection {
buffer.putInt(bytes.size.reverseBytes(), 4)
bytes.copyInto(buffer, 8)
}

stream.write(ByteBuffer.wrap(buffer))
val bytesWritten = stream.write(ByteBuffer.wrap(buffer))
if (bytesWritten == 0) throw ConnectionClosedException("The pipe has been closed")
} catch (e: AsynchronousCloseException) {
return
} catch (e: ConnectionClosedException) {
throw e
} catch (e: ClosedChannelException) {
throw ConnectionClosedException(e.message.orEmpty())
} catch (e: SocketException) {
if (e.message == "Connection reset") throw ConnectionClosedException(e.message.orEmpty())
throw PipeWriteException(e.message.orEmpty())
} catch (e: Exception) {
if (e.message == "Broken pipe") throw ConnectionClosedException(e.message.orEmpty())
throw PipeWriteException(e.message.orEmpty())
}
} ?: throw NotConnectedException()
Expand All @@ -195,7 +230,9 @@ actual class Connection {
private fun SocketChannel.readInt(): Int {
val buffer = ByteBuffer.allocate(4)

read(buffer)
val bytesRead = read(buffer)
if (bytesRead == 0) throw ConnectionClosedException("The pipe has been closed")

return ((buffer[0].toUInt() shl 24) +
(buffer[1].toUInt() shl 16) +
(buffer[2].toUInt() shl 8) +
Expand Down
4 changes: 2 additions & 2 deletions src/jvmTest/kotlin/io/github/vyfor/kpresence/ClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ class JVMClientTest {

client.connect(true)

repeat(2) {
repeat(20) {
client.update {
details = Random.nextInt().toString()
state = "KPresence"
}
Thread.sleep(15000)
Thread.sleep(1000)
}

client.shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ actual class Connection {
if (bytesWritten < 0) {
close()

if (errno == ECONNRESET || errno == ESHUTDOWN) throw ConnectionClosedException(strerror(errno)?.toKString().orEmpty())
throw PipeWriteException(strerror(errno)?.toKString().orEmpty())
}
}
Expand All @@ -97,6 +98,7 @@ actual class Connection {
recv(pipe, bytes.refTo(0), bytes.size.convert(), MSG_DONTWAIT).let { bytesRead ->
if (bytesRead < 0L) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return null
if (errno == ECONNRESET || errno == ESHUTDOWN) throw ConnectionClosedException(strerror(errno)?.toKString().orEmpty())
throw PipeReadException(strerror(errno)?.toKString().orEmpty())
} else if (bytesRead == 0L) {
close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ actual class Connection {
if (bytesWritten < 0) {
close()

if (errno == ECONNRESET || errno == ESHUTDOWN) throw ConnectionClosedException(strerror(errno)?.toKString().orEmpty())
throw PipeWriteException(strerror(errno)?.toKString().orEmpty())
}
}
Expand All @@ -97,6 +98,7 @@ actual class Connection {
recv(pipe, bytes.refTo(0), bytes.size.convert(), MSG_DONTWAIT).let { bytesRead ->
if (bytesRead < 0L) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return null
if (errno == ECONNRESET || errno == ESHUTDOWN) throw ConnectionClosedException(strerror(errno)?.toKString().orEmpty())
throw PipeReadException(strerror(errno)?.toKString().orEmpty())
} else if (bytesRead == 0L) {
close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ actual class Connection {
WriteFile(handle, it.addressOf(0), buffer.size.convert(), null, null)
}

if (success == FALSE) throw PipeWriteException(formatError(GetLastError()))
if (success == FALSE) {
val err = GetLastError()

if (err.toInt() == ERROR_BROKEN_PIPE) {
throw ConnectionClosedException(formatError(err))
}
throw PipeWriteException(formatError(err))
}
} ?: throw NotConnectedException()
}

Expand All @@ -80,7 +87,12 @@ actual class Connection {
)

if (result == FALSE) {
throw PipeReadException(formatError(GetLastError()))
val err = GetLastError()

if (err.toInt() == ERROR_BROKEN_PIPE) {
throw ConnectionClosedException(formatError(err))
}
throw PipeReadException(formatError(err))
}
if (bytesAvailable.value == 0u) {
return null
Expand All @@ -89,9 +101,14 @@ actual class Connection {
val bytes = ByteArray(size)
val bytesRead = alloc<UIntVar>()
bytes.usePinned { pinnedBytes ->
ReadFile(pipe, pinnedBytes.addressOf (0), size.convert(), bytesRead.ptr, null).let { success ->
ReadFile(pipe, pinnedBytes.addressOf(0), size.convert(), bytesRead.ptr, null).let { success ->
if (success == FALSE) {
throw PipeReadException(formatError(GetLastError()))
val err = GetLastError()

if (err.toInt() == ERROR_BROKEN_PIPE) {
throw ConnectionClosedException(formatError(err))
}
throw PipeReadException(formatError(err))
}
}
}
Expand Down

0 comments on commit 062bb0e

Please sign in to comment.