diff --git a/ext/botmapping.go b/ext/botmapping.go index 45d0355f..b785508c 100644 --- a/ext/botmapping.go +++ b/ext/botmapping.go @@ -163,9 +163,15 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit w.WriteHeader(http.StatusNotFound) return } + b.updateWriterControl.Add(1) defer b.updateWriterControl.Done() + if b.shouldStopUpdates() { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + headerSecret := r.Header.Get("X-Telegram-Bot-Api-Secret-Token") if b.webhookSecret != "" && b.webhookSecret != headerSecret { // Drop any updates from invalid secret tokens. @@ -184,10 +190,6 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit return } - if b.isUpdateChannelStopped() { - return - } - b.updateChan <- bytes } } @@ -213,7 +215,7 @@ func (b *botData) stop() { close(b.updateChan) } -func (b *botData) isUpdateChannelStopped() bool { +func (b *botData) shouldStopUpdates() bool { select { case <-b.stopUpdates: // if anything comes in on the closing channel, we know the channel is closed. diff --git a/ext/botmapping_test.go b/ext/botmapping_test.go index 64075482..bcbd06ea 100644 --- a/ext/botmapping_test.go +++ b/ext/botmapping_test.go @@ -89,13 +89,13 @@ func Test_botData_isUpdateChannelStopped(t *testing.T) { t.Errorf("bot with token %s should not have failed to be added", b.Token) return } - if bData.isUpdateChannelStopped() { + if bData.shouldStopUpdates() { t.Errorf("bot with token %s should not be stopped yet", b.Token) return } bData.stop() - if !bData.isUpdateChannelStopped() { + if !bData.shouldStopUpdates() { t.Errorf("bot with token %s should be stopped", b.Token) return } diff --git a/ext/updater.go b/ext/updater.go index 267ab074..e664617a 100644 --- a/ext/updater.go +++ b/ext/updater.go @@ -173,6 +173,11 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s defer bData.updateWriterControl.Done() for { + // Check if updater loop has been terminated. + if bData.shouldStopUpdates() { + return + } + // Manually craft the getUpdate calls to improve memory management, reduce json parsing overheads, and // unnecessary reallocation of url.Values in the polling loop. r, err := bData.bot.Request("getUpdates", v, nil, opts) @@ -219,10 +224,6 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s v["offset"] = strconv.FormatInt(lastUpdate.UpdateId+1, 10) - if bData.isUpdateChannelStopped() { - return - } - for _, updData := range rawUpdates { temp := updData // use new mem address to avoid loop conflicts bData.updateChan <- temp @@ -240,6 +241,9 @@ func (u *Updater) Idle() { } // Stop stops the current updater and dispatcher instances. +// +// When using long polling, Stop() will wait for the getUpdates call to return, which may cause a delay due to the +// request timeout. func (u *Updater) Stop() error { // Stop any running servers. if u.webhookServer != nil { diff --git a/ext/updater_test.go b/ext/updater_test.go index 359e5e81..bd1d9581 100644 --- a/ext/updater_test.go +++ b/ext/updater_test.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -98,8 +99,14 @@ func concurrentTest(t *testing.T) { t.Parallel() delay := time.Second - server := basicTestServer(t, map[string]testEndpoint{ - "getUpdates": {delay: delay, reply: `{"ok": true, "result": [{"message": {"text": "stop"}}]}`}, + server := basicTestServer(t, map[string]*testEndpoint{ + "getUpdates": { + delay: delay, + replies: []string{ + `{"ok": true, "result": [{"message": {"text": "stop"}}]}`, + }, + reply: `{"ok": true, "result": []}`, + }, "deleteWebhook": {reply: `{"ok": true, "result": true}`}, }) defer server.Close() @@ -290,7 +297,7 @@ func TestUpdater_GetHandlerFunc(t *testing.T) { } func TestUpdaterAllowsWebhookDeletion(t *testing.T) { - server := basicTestServer(t, map[string]testEndpoint{ + server := basicTestServer(t, map[string]*testEndpoint{ "getUpdates": {reply: `{"ok": true}`}, "deleteWebhook": {reply: `{"ok": true, "result": true}`}, }) @@ -329,7 +336,7 @@ func TestUpdaterAllowsWebhookDeletion(t *testing.T) { } func TestUpdaterSupportsTwoPollingBots(t *testing.T) { - server := basicTestServer(t, map[string]testEndpoint{ + server := basicTestServer(t, map[string]*testEndpoint{ "getUpdates": {reply: `{"ok": true, "result": []}`}, }) defer server.Close() @@ -384,7 +391,7 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) { } func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) { - server := basicTestServer(t, map[string]testEndpoint{ + server := basicTestServer(t, map[string]*testEndpoint{ "getUpdates": {reply: `{"ok": true, "result": []}`}, }) defer server.Close() @@ -432,7 +439,7 @@ func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) { } func TestUpdaterSupportsLongPollReAdding(t *testing.T) { - server := basicTestServer(t, map[string]testEndpoint{ + server := basicTestServer(t, map[string]*testEndpoint{ "getUpdates": {reply: `{"ok": true, "result": []}`}, }) defer server.Close() @@ -484,10 +491,14 @@ func TestUpdaterSupportsLongPollReAdding(t *testing.T) { type testEndpoint struct { delay time.Duration + // Will reply these until we run out of replies, at which point we repeat "reply" + replies []string + idx atomic.Int32 + // default reply reply string } -func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Server { +func basicTestServer(t *testing.T, methods map[string]*testEndpoint) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pathItems := strings.Split(r.URL.Path, "/") lastItem := pathItems[len(pathItems)-1] @@ -498,7 +509,12 @@ func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Se if out.delay != 0 { time.Sleep(out.delay) } - fmt.Fprint(w, out.reply) + count := int(out.idx.Add(1) - 1) + if len(out.replies) != 0 && len(out.replies) > count { + fmt.Fprint(w, out.replies[count]) + } else { + fmt.Fprint(w, out.reply) + } return } diff --git a/samples/echoMultiBot/main.go b/samples/echoMultiBot/main.go index c84a6a40..a57d9d8a 100644 --- a/samples/echoMultiBot/main.go +++ b/samples/echoMultiBot/main.go @@ -86,7 +86,10 @@ func main() { // If we get here, the updater.Idle() has ended. // This means that updater.Stop() has been called, stopping all bots gracefully. - log.Println("Updater is no longer idling; all bots have been stopped gracefully.") + log.Println("Updater is no longer idling; all bots have been stopped gracefully. Exiting in 1s.") + + // We sleep one last second to allow for the "stopall" goroutine to send the shutdown message. + time.Sleep(time.Second) } // startLongPollingBots demonstrates how to start multiple bots with long-polling. @@ -159,11 +162,14 @@ func stop(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error { return fmt.Errorf("failed to echo message: %w", err) } - if !updater.StopBot(b.Token) { - ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil) - return nil - } + go func() { + if !updater.StopBot(b.Token) { + ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil) + return + } + ctx.EffectiveMessage.Reply(b, "Stopped @"+b.Username, nil) + }() return nil } @@ -181,6 +187,7 @@ func stopAll(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error { ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Failed to stop updater: %s", err.Error()), nil) return } + ctx.EffectiveMessage.Reply(b, "All bots have been stopped.", nil) }() return nil