diff --git a/gbn/gbn_client.go b/gbn/gbn_client.go index 6cb43b27..4f3128f1 100644 --- a/gbn/gbn_client.go +++ b/gbn/gbn_client.go @@ -21,12 +21,7 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, math.MaxUint8) } - conn := newGoBackNConn(ctx, sendFunc, receiveFunc, false, n) - - // Apply functional options - for _, o := range opts { - o(conn) - } + conn := newGoBackNConn(ctx, sendFunc, receiveFunc, false, n, opts...) if err := conn.clientHandshake(); err != nil { if err := conn.Close(); err != nil { diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index ab52bd69..4cebc5c2 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -113,21 +113,20 @@ type GoBackNConn struct { // newGoBackNConn creates a GoBackNConn instance with all the members which // are common between client and server initialised. +// +//nolint:varnamelen func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, - recvFunc recvBytesFunc, isServer bool, n uint8) *GoBackNConn { + recvFunc recvBytesFunc, isServer bool, n uint8, + opts ...Option) *GoBackNConn { ctxc, cancel := context.WithCancel(ctx) - return &GoBackNConn{ - n: n, - s: n + 1, + gbn := &GoBackNConn{ resendTimeout: defaultResendTimeout, recvFromStream: recvFunc, sendToStream: sendFunc, - recvDataChan: make(chan *PacketData, n), sendDataChan: make(chan *PacketData), isServer: isServer, - sendQueue: newQueue(n+1, defaultHandshakeTimeout), handshakeTimeout: defaultHandshakeTimeout, recvTimeout: DefaultRecvTimeout, sendTimeout: DefaultSendTimeout, @@ -138,6 +137,14 @@ func newGoBackNConn(ctx context.Context, sendFunc sendBytesFunc, cancel: cancel, quit: make(chan struct{}), } + + for _, o := range opts { + o(gbn) + } + + gbn.setN(n) + + return gbn } // setN sets the current N to use. 
This _must_ be set before the handshake is @@ -146,7 +153,12 @@ func (g *GoBackNConn) setN(n uint8) { g.n = n g.s = n + 1 g.recvDataChan = make(chan *PacketData, n) - g.sendQueue = newQueue(n+1, defaultHandshakeTimeout) + g.sendQueue = newQueue(&queueConfig{ + s: g.s, + sendPkt: func(packet *PacketData) error { + return g.sendPacket(g.ctx, packet) + }, + }) } // SetSendTimeout sets the timeout used in the Send function. @@ -348,6 +360,8 @@ func (g *GoBackNConn) Close() error { // initialisation. g.cancel() + g.sendQueue.stop() + g.wg.Wait() if g.pingTicker != nil { @@ -387,9 +401,17 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { func (g *GoBackNConn) sendPacketsForever() error { // resendQueue re-sends the current contents of the queue. resendQueue := func() error { - return g.sendQueue.resend(func(packet *PacketData) error { - return g.sendPacket(g.ctx, packet) - }) + err := g.sendQueue.resend(g.resendTimeout) + + // After resending the queue, we reset the resend ticker. + // This is so that we don't immediately resend the queue again, + // if the sendQueue.resend call above took a long time to + // execute. That can happen if the function was awaiting the + // expected ACK for a long time, or times out while awaiting the + // catch up. + g.resendTicker.Reset(g.resendTimeout) + + return err } for { @@ -578,7 +600,10 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo } case *PacketACK: - gotValidACK := g.sendQueue.processACK(m.Seq) + gotValidACK := g.sendQueue.processACK( + m.Seq, g.resendTimeout, + ) + if gotValidACK { g.resendTicker.Reset(g.resendTimeout) @@ -597,15 +622,12 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // sent was dropped, or maybe we sent a duplicate // message. The NACK message contains the sequence // number that the receiver was expecting. 
- inQueue, bumped := g.sendQueue.processNACK(m.Seq) - - // If the NACK sequence number is not in our queue - // then we ignore it. We must have received the ACK - // for the sequence number in the meantime. - if !inQueue { - log.Tracef("NACK seq %d is not in the queue. "+ - "Ignoring. (isServer=%v)", m.Seq, - g.isServer) + shouldResend, bumped := g.sendQueue.processNACK(m.Seq) + + // If we don't need to resend the queue after processing + // the NACK, we can continue without sending the resend + // signal. + if !shouldResend { continue } diff --git a/gbn/gbn_server.go b/gbn/gbn_server.go index 488a45db..ec8a7d93 100644 --- a/gbn/gbn_server.go +++ b/gbn/gbn_server.go @@ -14,12 +14,7 @@ import ( func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, recvFunc recvBytesFunc, opts ...Option) (*GoBackNConn, error) { - conn := newGoBackNConn(ctx, sendFunc, recvFunc, true, DefaultN) - - // Apply functional options - for _, o := range opts { - o(conn) - } + conn := newGoBackNConn(ctx, sendFunc, recvFunc, true, DefaultN, opts...) if err := conn.serverHandshake(); err != nil { if err := conn.Close(); err != nil { diff --git a/gbn/queue.go b/gbn/queue.go index 23de65a0..90eb4186 100644 --- a/gbn/queue.go +++ b/gbn/queue.go @@ -5,13 +5,25 @@ import ( "time" ) -// queue is a fixed size queue with a sliding window that has a base and a top -// modulo s. -type queue struct { - // content is the current content of the queue. This is always a slice - // of length s but can contain nil elements if the queue isn't full. - content []*PacketData +const ( + // awaitingTimeoutMultiplier defines the multiplier we use when + // multiplying the resend timeout during a resend catch up, resulting in + // duration we wait for the resend catch up to complete before timing + // out. + // We set this to 3X the resend timeout. 
The reason we wait exactly 3X + // the resend timeout is that we expect that the max time correct + // behavior would take, would be: + // * 1X the resendTimeout for the time it would take for the party to + // respond with an ACK for the last packet in the resend queue, i.e. the + // awaitedACK. + // * 1X the resendTimeout while awaiting the proceedAfterTime callback + // to be executed. + // * 1X extra resendTimeout as buffer, to ensure that we have enough + // time to process the ACKs/NACKs by the other party + extra margin. + awaitingTimeoutMultiplier = 3 +) +type queueConfig struct { // s is the maximum sequence number used to label packets. Packets // are labelled with incrementing sequence numbers modulo s. // s must be strictly larger than the window size, n. This @@ -21,6 +33,18 @@ type queue struct { // no way to tell. s uint8 + sendPkt func(packet *PacketData) error +} + +// queue is a fixed size queue with a sliding window that has a base and a top +// modulo s. +type queue struct { + cfg *queueConfig + + // content is the current content of the queue. This is always a slice + // of length s but can contain nil elements if the queue isn't full. + content []*PacketData + // sequenceBase keeps track of the base of the send window and so // represents the next ack that we expect from the receiver. The // maximum value of sequenceBase is s. @@ -39,19 +63,66 @@ type queue struct { // topMtx is used to guard sequenceTop. topMtx sync.RWMutex - lastResend time.Time - handshakeTimeout time.Duration + // awaitedACK defines the sequence number for the last packet in the + // resend queue. If we receive an ACK for this sequence number during + // the resend catch up, we wait for the duration of the resend timeout, + // and then proceed to send new packets, unless we receive the + // awaitedNACK during the wait time. If that happens, we will proceed + // to send new packets as soon as we have processed the NACK. 
+ awaitedACK uint8 + + // awaitedNACK is the sequence number which, if received in a NACK + // during the resend catch up, means that we consider the catch up to + // be complete and can proceed to send new packets. + awaitedNACK uint8 + + // awaitingCatchUp is set to true if we are awaiting a catch up after we + // have resent the queue. + awaitingCatchUp bool + + // awaitingCatchUpMu must be held when accessing or mutating the values + // of awaitedACK, awaitedNACK and awaitingCatchUp. + awaitingCatchUpMu sync.RWMutex + + // awaitedACKSignal is used to signal that we have received the awaited + // ACK after resending the queue, and have waited for the duration of + // the resend timeout. Once this signal is received, we can proceed to + // send new packets. + awaitedACKSignal chan struct{} + + // awaitedNACKSignal is used to signal that we have received the awaited + // NACK after resending the queue. Once this signal is received, we can + // proceed to send new packets. + awaitedNACKSignal chan struct{} + + // caughtUpSignal is used to signal that we have caught up after + // awaiting the catch up after resending the queue. + // This channel will be a unique channel for every cycle of the resend + // catch up, and is used to ensure that we only send an awaitedACKSignal + // if we're still awaiting the resend catch up, after we have waited for + // the resend timeout after the awaitedACK has been received. + caughtUpSignal chan struct{} + + lastResend time.Time + + quit chan struct{} } // newQueue creates a new queue. 
-func newQueue(s uint8, handshakeTimeout time.Duration) *queue { +func newQueue(cfg *queueConfig) *queue { return &queue{ - content: make([]*PacketData, s), - s: s, - handshakeTimeout: handshakeTimeout, + cfg: cfg, + content: make([]*PacketData, cfg.s), + awaitedACKSignal: make(chan struct{}, 1), + awaitedNACKSignal: make(chan struct{}, 1), + quit: make(chan struct{}), } } +func (q *queue) stop() { + close(q.quit) +} + // size is used to calculate the current sender queueSize. func (q *queue) size() uint8 { q.baseMtx.RLock() @@ -64,7 +135,7 @@ func (q *queue) size() uint8 { return q.sequenceTop - q.sequenceBase } - return q.sequenceTop + (q.s - q.sequenceBase) + return q.sequenceTop + (q.cfg.s - q.sequenceBase) } // addPacket adds a new packet to the queue. @@ -74,12 +145,68 @@ func (q *queue) addPacket(packet *PacketData) { packet.Seq = q.sequenceTop q.content[q.sequenceTop] = packet - q.sequenceTop = (q.sequenceTop + 1) % q.s + q.sequenceTop = (q.sequenceTop + 1) % q.cfg.s } -// resend invokes the callback for each packet that needs to be re-sent. -func (q *queue) resend(cb func(packet *PacketData) error) error { - if time.Since(q.lastResend) < q.handshakeTimeout { +// resend resends the current contents of the queue, by invoking the callback +// for each packet that needs to be resent, and then awaits that we either +// receive the expected ACK or NACK after resending the queue, before returning. +// +// To understand why we need to await the awaited ACK/NACK after resending the +// queue, it ensures that we don't end up in a situation where we resend the +// queue over and over again due to latency and delayed NACKs by the other +// party. +// +// Consider the following scenario: +// 1. +// Alice sends packets 1, 2, 3 & 4 to Bob. +// 2. +// Bob receives packets 1, 2, 3 & 4, and sends back the respective ACKs. +// 3. 
+// Alice receives ACKs for packets 1 & 2, but due to latency the ACKs for +// packets 3 & 4 are delayed and aren't received until Alice's resend timeout +// has passed, which leads to Alice resending packets 3 & 4. Alice will after +// that receive the delayed ACKs for packets 3 & 4, but will consider that as +// the ACKs for the resent packets, and not the original packets which they were +// actually sent for. If we didn't wait after resending the queue, Alice would +// then proceed to send more packets (5 & 6). +// 4. +// When Bob receives the resent packets 3 & 4, Bob will respond with NACK 5. Due +// to latency, the packets 5 & 6 that Alice sent in step (3) above will then be +// received by Bob, and be processed as the correct response to the NACK 5. Bob +// will after that await packet 7. +// 5. +// Alice will receive the NACK 5, and now resend packets 5 & 6. But as Bob is +// now awaiting packet 7, this send will lead to a NACK 7. But due to latency, +// if Alice doesn't wait after resending the queue, Alice will proceed to send +// new packet(s) before receiving the NACK 7. +// 6. +// This resend loop would continue indefinitely, so we need to ensure that Alice +// waits after she has resent the queue, to ensure that she doesn't proceed to +// send new packets before she is sure that both parties are in sync. +// +// To ensure that we are in sync, after we have resent the queue, we will await +// that we either: +// 1. Receive a NACK for the sequence number succeeding the last packet in the +// resent queue i.e. in step (3) above, that would be NACK 5. +// OR +// 2. Receive an ACK for the last packet in the resent queue i.e. in step (3) +// above, that would be ACK 4. After we receive the expected ACK, we will then +// wait for the duration of the resend timeout before continuing. 
The reason why +// we wait for the resend timeout before continuing, is that the ACKs we are +// getting after a resend, could be delayed ACKs for the original packets we +// sent, and not ACKs for the resent packets. In step (3) above, the ACKs for +// packets 3 & 4 that Alice received were delayed ACKs for the original packets. +// If Alice would have immediately continued to send new packets (5 & 6) after +// receiving the ACK 4, she would have then received the NACK 5 from Bob which +// was the actual response to the resent queue. But as Alice had already +// continued to send packets 5 & 6 when receiving the NACK 5, the resend queue +// response to that NACK would cause the resend loop to continue indefinitely. +// +// When either of the 2 conditions above are met, we will consider both parties +// to be in sync, and we can proceed to send new packets. +func (q *queue) resend(resendTimeout time.Duration) error { + if time.Since(q.lastResend) < resendTimeout { log.Tracef("Resent the queue recently.") return nil @@ -91,6 +218,8 @@ func (q *queue) resend(cb func(packet *PacketData) error) error { q.lastResend = time.Now() + q.awaitingCatchUpMu.Lock() + q.baseMtx.RLock() base := q.sequenceBase q.baseMtx.RUnlock() @@ -100,35 +229,158 @@ func (q *queue) resend(cb func(packet *PacketData) error) error { q.topMtx.RUnlock() if base == top { + q.awaitingCatchUpMu.Unlock() + return nil } + if q.noPingPackets(base, top) { + q.awaitedACK = (q.cfg.s + top - 1) % q.cfg.s + q.awaitedNACK = top + + log.Tracef("Set awaitedACK to %d & awaitedNACK to %d", + q.awaitedACK, q.awaitedNACK) + + q.awaitingCatchUp = true + + // Create a new instance of the caughtUpSignal channel, so that + // the new resend catchup cycle uses a unique channel. 
+ q.caughtUpSignal = make(chan struct{}, 1) + } else { + log.Tracef("Won't catch up due to ping packet(s) only") + + q.awaitingCatchUp = false + } + log.Tracef("Resending the queue") for base != top { packet := q.content[base] - if err := cb(packet); err != nil { + if err := q.cfg.sendPkt(packet); err != nil { + q.awaitingCatchUpMu.Unlock() + return err } - base = (base + 1) % q.s + + base = (base + 1) % q.cfg.s log.Tracef("Resent %d", packet.Seq) } + if !q.awaitingCatchUp { + q.awaitingCatchUpMu.Unlock() + + return nil + } + + // We hold the awaitingCatchUpMu mutex for the duration of the resend to + // ensure that we don't process the delayed ACKs for the packets we are + // resending, during the resend. If that would happen, we would start + // the "proceedAfterTime" callback timeout while still resending + // packets. That could mean that the NACK that the resent packets will + // trigger, might be received after the timeout has passed. That would + // cause the resend loop to trigger once more. + q.awaitingCatchUpMu.Unlock() + + // Then await until we know that both parties are in sync. + q.awaitCatchUp(resendTimeout) + return nil } -// processACK processes an incoming ACK of a given sequence number. -func (q *queue) processACK(seq uint8) bool { +// awaitCatchUp awaits that we either receive the awaited ACK or NACK signal +// before returning. If we don't receive the awaited ACK or NACK signal before +// 3X the resend timeout, the function will also return. +// See the docs for the resend function for more details on why we need to await +// the awaited ACK or NACK signal. 
+// +//nolint:cyclop +func (q *queue) awaitCatchUp(resendTimeout time.Duration) { + ticker := time.NewTimer(resendTimeout * awaitingTimeoutMultiplier) + defer ticker.Stop() + + log.Tracef("Awaiting catchup after resending the queue") + +catchupLoop: + for { + select { + case <-q.quit: + return + case <-q.awaitedACKSignal: + log.Tracef("Got awaitedACKSignal") + + break catchupLoop + case <-q.awaitedNACKSignal: + log.Tracef("Got awaitedNACKSignal") + + break catchupLoop + case <-ticker.C: + log.Tracef("Timed out while awaiting catchup") + + q.awaitingCatchUpMu.Lock() + q.awaitingCatchUp = false + + // Drain both the ACK & NACK signal channels. + select { + case <-q.awaitedACKSignal: + default: + } + + select { + case <-q.awaitedNACKSignal: + default: + } + + q.awaitingCatchUpMu.Unlock() + + break catchupLoop + default: + continue + } + } + + // Send a caughtUpSignal to indicate that we have caught up after + // resending the queue. + q.caughtUpSignal <- struct{}{} +} + +// noPingPackets returns true if all the packets for the given packet indices +// are not Ping packets. +func (q *queue) noPingPackets(base, top uint8) bool { + for base != top { + packet := q.content[base] + + if packet.IsPing { + return false + } + + base = (base + 1) % q.cfg.s + } + return true +} + +// processACK processes an incoming ACK of a given sequence number. +func (q *queue) processACK(seq uint8, resendTimeout time.Duration) bool { // If our queue is empty, an ACK should not have any effect. if q.size() == 0 { - log.Tracef("Received ack %d, but queue is empty. Ignoring.", - seq) + log.Tracef("Received ack %d, but queue is empty. Ignoring.", seq) return false } + // If we are awaiting a catch up, and the ACK is the awaited ACK, we + // start the proceedAfterTime callback, which will send an + // awaitedACKSignal if we're still awaiting the resend catch up when + // the callback is executed. 
+ q.awaitingCatchUpMu.RLock() + if seq == q.awaitedACK && q.awaitingCatchUp { + log.Tracef("Got awaited ACK") + + q.proceedAfterTime(q.caughtUpSignal, resendTimeout) + } + q.awaitingCatchUpMu.RUnlock() + q.baseMtx.Lock() defer q.baseMtx.Unlock() @@ -139,7 +391,7 @@ func (q *queue) processACK(seq uint8) bool { // has decreased. log.Tracef("Received correct ack %d", seq) - q.sequenceBase = (q.sequenceBase + 1) % q.s + q.sequenceBase = (q.sequenceBase + 1) % q.cfg.s // We did receive an ACK. return true @@ -160,7 +412,7 @@ func (q *queue) processACK(seq uint8) bool { if containsSequence(q.sequenceBase, q.sequenceTop, seq) { log.Tracef("Sequence %d is in the queue. Bump the base.", seq) - q.sequenceBase = (seq + 1) % q.s + q.sequenceBase = (seq + 1) % q.cfg.s // We did receive an ACK. return true @@ -172,6 +424,9 @@ func (q *queue) processACK(seq uint8) bool { // processNACK processes an incoming NACK of a given sequence number. func (q *queue) processNACK(seq uint8) (bool, bool) { + q.awaitingCatchUpMu.Lock() + defer q.awaitingCatchUpMu.Unlock() + q.baseMtx.Lock() defer q.baseMtx.Unlock() @@ -180,17 +435,39 @@ func (q *queue) processNACK(seq uint8) (bool, bool) { log.Tracef("Received NACK %d", seq) - // If the NACK is the same as sequenceTop, it probably means that queue + if q.awaitingCatchUp && seq == q.awaitedNACK { + log.Tracef("Sending awaitedNACKSignal") + q.awaitedNACKSignal <- struct{}{} + + q.awaitingCatchUp = false + + // In case the awaitedNACK is the same as sequenceTop, we can + // bump the base to be equal to sequenceTop, without triggering + // a new resend. + if seq == q.sequenceTop { + q.sequenceBase = q.sequenceTop + } + + // If we receive the awaited NACK, we shouldn't trigger a new + // resend, as we can now proceed to send new packets. 
+ return false, false + } + + // If the NACK is the same as sequenceTop, and we weren't awaiting this + // NACK as part of the resend catch up, it probably means that queue // was sent successfully, but we just missed the necessary ACKs. So we - // can empty the queue here by bumping the base and we dont need to + // can empty the queue here by bumping the base and we don't need to // trigger a resend. if seq == q.sequenceTop { q.sequenceBase = q.sequenceTop - return true, false + + return false, false } // Is the NACKed sequence even in our queue? if !containsSequence(q.sequenceBase, q.sequenceTop, seq) { + log.Tracef("NACK seq %d is not in the queue. Ignoring.", seq) + return false, false } @@ -206,6 +483,48 @@ func (q *queue) processNACK(seq uint8) (bool, bool) { return true, bumped } +// proceedAfterTime will wait for the resendTimeout and then send an +// awaitedACKSignal, if we're still awaiting the resend catch up. +func (q *queue) proceedAfterTime(caughtUpSignal chan struct{}, + resendTimeout time.Duration) { + + processAwaitedACK := func() { + log.Tracef("Executing proceedAfterTime") + + // We want to ensure that this function only sends an + // awaitedACKSignal if proceedAfterTime is executed while we're + // still awaiting the resend catch up in the resend catch up + // cycle it was initiated for. Therefore we check that we + // haven't already caught up. 
+ select { + case <-caughtUpSignal: + log.Tracef("Already caught up") + + return + default: + } + + q.awaitingCatchUpMu.Lock() + + if q.awaitingCatchUp { + log.Tracef("Sending awaitedACKSignal") + q.awaitedACKSignal <- struct{}{} + + q.awaitingCatchUp = false + } else { + log.Tracef("Ending proceedAfterTime without any action") + } + + q.awaitingCatchUpMu.Unlock() + } + + // We await for the duration of the resendTimeout before executing the + // proceedAfterTime callback, as that's the time we'd expect it to take + // for the other party to respond with a NACK, if the resent last packet + // in the queue would lead to a NACK. + time.AfterFunc(resendTimeout, processAwaitedACK) +} + // containsSequence is used to determine if a number, seq, is between two other // numbers, base and top, where all the numbers lie in a finite field (modulo // space) s. diff --git a/gbn/queue_test.go b/gbn/queue_test.go index b91fc141..8aa1ec2c 100644 --- a/gbn/queue_test.go +++ b/gbn/queue_test.go @@ -7,15 +7,20 @@ import ( ) func TestQueueSize(t *testing.T) { - q := newQueue(4, 0) + queue := newQueue(&queueConfig{ + s: 4, + sendPkt: func(packet *PacketData) error { + return nil + }, + }) - require.Equal(t, uint8(0), q.size()) + require.Equal(t, uint8(0), queue.size()) - q.sequenceBase = 2 - q.sequenceTop = 3 - require.Equal(t, uint8(1), q.size()) + queue.sequenceBase = 2 + queue.sequenceTop = 3 + require.Equal(t, uint8(1), queue.size()) - q.sequenceBase = 3 - q.sequenceTop = 2 - require.Equal(t, uint8(3), q.size()) + queue.sequenceBase = 3 + queue.sequenceTop = 2 + require.Equal(t, uint8(3), queue.size()) }