Skip to content

Commit

Permalink
Fix issues found after adding async-await support (#44)
Browse files Browse the repository at this point in the history
* Do not process messages from previous joins
* Allow rejoining a channel
  • Loading branch information
atdrendel authored Jan 1, 2023
1 parent aa8ac6c commit b141808
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 67 deletions.
115 changes: 73 additions & 42 deletions Sources/Phoenix/Channel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,6 @@ extension Channel {
}

let fut = sync { () -> JoinFuture? in
guard shouldRejoin else {
precondition(joinFuture == nil)
return nil
}

self.customTimeout = customTimeout

os_log(
Expand Down Expand Up @@ -292,6 +287,10 @@ extension Channel {
case .joining, .joined:
return
case .closed, .errored, .leaving:
if joinFuture == nil {
joinFuture = JoinFuture()
}

let ref = socket.advanceRef()
self.state = .joining(ref)
self.writeJoinPushAsync()
Expand All @@ -303,7 +302,13 @@ extension Channel {
sync {
switch self.state {
case let .joining(joinRef):
let message = OutgoingMessage(joinPush, ref: joinRef, joinRef: joinRef)
precondition(joinFuture != nil)

let message = OutgoingMessage(
joinPush,
ref: joinRef,
joinRef: joinRef
)

createJoinTimer()

Expand Down Expand Up @@ -536,7 +541,7 @@ extension Channel {
self.sync {
// put it back to try again later
self._inFlight[ref] = nil
self._pending.append(push)
self._pending.insert(push, at: 0)
}
} else {
self.flushAsync()
Expand Down Expand Up @@ -599,15 +604,17 @@ extension Channel {
private func timeoutInFlightMessages() {
sync {
// invalidate a previous timer if it's there
self.inFlightMessagesTimer?.invalidate()
self.inFlightMessagesTimer = nil

guard !_inFlight.isEmpty else { return }

let now = DispatchTime.now()

let messages = _inFlight.values.sortedByTimeoutDate().filter {
$0.timeoutDate < now
}
let messages = _inFlight
.values
.sortedByTimeoutDate()
.filter { $0.timeoutDate < now }

for message in messages {
_inFlight[message.ref] = nil
Expand All @@ -626,18 +633,20 @@ extension Channel {
sync {
guard _inFlight.isNotEmpty else { return }

let possibleNext = _inFlight.values.sortedByTimeoutDate().first
let possibleNext = _inFlight
.values
.sortedByTimeoutDate()
.first

guard let next = possibleNext else { return }
guard next.timeoutDate < inFlightMessagesTimer?.nextDeadline else { return }

self
.inFlightMessagesTimer = DispatchTimer(
fireAt: next
.timeoutDate
) { [weak self] in
self?.timeoutInFlightMessagesAsync()
}
guard next.timeoutDate < inFlightMessagesTimer?.nextDeadline
else { return }

self.inFlightMessagesTimer = DispatchTimer(
fireAt: next.timeoutDate
) { [weak self] in
self?.timeoutInFlightMessagesAsync()
}
}
}
}
Expand Down Expand Up @@ -700,6 +709,9 @@ extension Channel {
let receiveValue = { [weak self] (input: SocketOutput) in
switch input {
case let .channelMessage(message):
guard message.joinRef == nil ||
message.joinRef == self?.joinRef
else { return }
self?.handle(message)
case .socketOpen:
self?.handleSocketOpen()
Expand All @@ -725,6 +737,9 @@ extension Channel {
case .joining:
writeJoinPushAsync()
case .errored where shouldRejoin:
if joinFuture == nil {
joinFuture = JoinFuture()
}
let ref = socket.advanceRef()
self.state = .joining(ref)
writeJoinPushAsync()
Expand Down Expand Up @@ -797,43 +812,65 @@ extension Channel {

private func handle(_ reply: Channel.Reply) {
sync {
switch state {
case let .joining(joinRef):
guard reply.ref == joinRef,
reply.joinRef == joinRef,
reply.isOk
func putInFlightBackIntoQueue(_ pushed: PushedMessage) {
guard pushed.push.event.isCustom else { return }

self._inFlight[reply.ref] = nil
self._pending.append(pushed.push)
}

if case let .joined(joinRef) = state {
guard let pushed = _inFlight.removeValue(forKey: reply.ref)
else { return }

if reply.joinRef == joinRef {
pushed.callback(reply: reply)
} else {
putInFlightBackIntoQueue(pushed)
}

return
}

if case let .joining(joinRef) = state {
guard reply.ref == joinRef, reply.joinRef == joinRef
else {
let fut = joinFuture
joinFuture = nil
fut?.fail(Channel.Error.invalidJoinReply(reply))
if let pushed = _inFlight.removeValue(forKey: reply.ref) {
putInFlightBackIntoQueue(pushed)
}
return
}

guard reply.isOk else {
self.errored(Channel.Error.invalidJoinReply(reply))
self.createRejoinTimer()
break
return
}

self.state = .joined(joinRef)

let subject = self.subject
notifySubjectQueue.async { subject.send(.join(reply.message)) }
notifySubjectQueue.async {
subject.send(.join(reply.message))
}

leaveFuture?.resolve()
joinFuture?.resolve(reply.message.payload)

self._joinTimer = .off

flushAsync()
}

case let .joined(joinRef):
guard let pushed = _inFlight.removeValue(forKey: reply.ref),
if case let .leaving(joinRef, leavingRef) = state {
guard reply.ref == leavingRef,
reply.joinRef == joinRef
else {
if let pushed = _inFlight.removeValue(forKey: reply.ref) {
putInFlightBackIntoQueue(pushed)
}
return
}
pushed.callback(reply: reply)

case let .leaving(joinRef, leavingRef):
guard reply.ref == leavingRef, reply.joinRef == joinRef else { break }

self.state = .closed
self.sendLeaveAndCompletionToSubjectAsync()
Expand All @@ -845,12 +882,6 @@ extension Channel {
let leaveFut = leaveFuture
leaveFuture = nil
leaveFut?.resolve()

case .closed:
break

default:
break
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions Sources/Phoenix/PhxEvent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,9 @@ public enum PhxEvent: Equatable, ExpressibleByStringLiteral {
public static func == (lhs: PhxEvent, rhs: PhxEvent) -> Bool {
lhs.stringValue == rhs.stringValue
}

var isCustom: Bool {
guard case .custom = self else { return false }
return true
}
}
48 changes: 27 additions & 21 deletions Sources/Phoenix/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public final class Socket {
private let subject = PassthroughSubject<Output, Failure>()
private var shouldReconnect = true
private var webSocketSubscriber: AnyCancellable?
private var channels = [Topic: WeakChannel]()
private var channels = [Topic: Channel]()

private var connectFuture: ConnectFuture?

Expand All @@ -36,8 +36,7 @@ public final class Socket {
#endif

public var joinedChannels: [Channel] {
let channels = sync { self.channels }
return channels.compactMap(\.value.channel)
sync { Array(self.channels.values) }
}

private var pending: [Push] = []
Expand Down Expand Up @@ -237,7 +236,7 @@ extension Socket: ConnectablePublisher {

switch state {
case .closed:
precondition(connectFuture == nil)
connectFuture?.fail(CancellationError())
let fut = ConnectFuture()
connectFuture = fut

Expand All @@ -258,10 +257,13 @@ extension Socket: ConnectablePublisher {
return fut
}

case .connecting, .open:
case .connecting:
let fut = connectFuture
return { fut }

case .open:
return { nil }

case .closing:
let fut = connectFuture
return { fut }
Expand Down Expand Up @@ -290,7 +292,7 @@ extension Socket: ConnectablePublisher {

// Calling `Channel.leave()` inside `sync` can cause a deadlock.
let channels: [Channel] = sync {
let channels = self.channels.compactMap(\.value.channel)
let channels = Array(self.channels.values)
self.channels.removeAll()
return channels
}
Expand Down Expand Up @@ -350,7 +352,8 @@ extension Socket: ConnectablePublisher {
return result
}

private func reconnectIfPossible() {
@discardableResult
private func reconnectIfPossible() -> Bool {
sync {
if shouldReconnect {
_reconnectAttempts += 1
Expand All @@ -361,6 +364,9 @@ extension Socket: ConnectablePublisher {
guard self.lock.locked({ self.shouldReconnect }) else { return }
self.connect()
}
return true
} else {
return false
}
}
}
Expand All @@ -387,17 +393,19 @@ public extension Socket {

func channel(_ topic: Topic, payload: Payload = [:]) -> Channel {
sync {
if let weakChannel = channels[topic],
let _channel = weakChannel.channel
{
return _channel
if let channel = channels[topic] {
return channel
}

let _channel = Channel(topic: topic, joinPayload: payload, socket: self)
let channel = Channel(
topic: topic,
joinPayload: payload,
socket: self
)

channels[topic] = WeakChannel(_channel)
channels[topic] = channel

return _channel
return channel
}
}

Expand All @@ -420,9 +428,7 @@ public extension Socket {

@discardableResult
private func removeChannel(for topic: Topic) -> Channel? {
let weakChannel = sync { channels.removeValue(forKey: topic) }
guard let channel = weakChannel?.channel else { return nil }
return channel
sync { channels.removeValue(forKey: topic) }
}
}

Expand Down Expand Up @@ -750,10 +756,10 @@ extension Socket {
let subject = self.subject
notifySubjectQueue.async { subject.send(.close) }

connectFuture?.resolve()
connectFuture = nil

reconnectIfPossible()
if !reconnectIfPossible() {
connectFuture?.resolve()
connectFuture = nil
}
}
}
}
Expand Down
5 changes: 1 addition & 4 deletions Tests/PhoenixTests/ChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,8 @@ class ChannelTests: XCTestCase {

do {
try await channel.join()
XCTFail("Should have failed")
} catch {
guard let e = error as? Channel.Error,
case .unableToJoin = e
else { return XCTFail("Should have failed with .unableToJoin") }
XCTFail("Join should have succeeded but received error: \(error)")
}
}

Expand Down

0 comments on commit b141808

Please sign in to comment.