Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve concurrent safety #127

Merged
merged 6 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions ext/botmapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@ import (
type botData struct {
// bot represents the bot for which this data is relevant.
bot *gotgbot.Bot

// updateChan represents the incoming updates channel.
updateChan chan json.RawMessage
// polling allows us to close the polling loop.
polling chan struct{}
// updateWriterControl is used to count the number of current writers on the update channel.
// This is required to ensure that we can safely close the channel, and thus stop processing incoming updates.
// While this remains non-zero, it is unsafe to close the update channel.
updateWriterControl *sync.WaitGroup
// stopUpdates allows us to close the stopUpdates loop.
stopUpdates chan struct{}

// urlPath defines the incoming webhook URL path for this bot.
urlPath string
// webhookSecret stores the webhook secret for this bot
// webhookSecret stores the webhook secret for this bot.
webhookSecret string
}

Expand All @@ -45,7 +51,8 @@ var ErrBotAlreadyExists = errors.New("bot already exists in bot mapping")
var ErrBotUrlPathAlreadyExists = errors.New("url path already exists in bot mapping")

// addBot Adds a new bot to the botMapping structure.
func (m *botMapping) addBot(bData botData) error {
// Pass an empty urlPath/webhookSecret if using polling instead of webhooks.
func (m *botMapping) addBot(b *gotgbot.Bot, urlPath string, webhookSecret string) (*botData, error) {
m.mux.Lock()
defer m.mux.Unlock()

Expand All @@ -56,16 +63,26 @@ func (m *botMapping) addBot(bData botData) error {
m.urlMapping = make(map[string]string)
}

if _, ok := m.mapping[bData.bot.Token]; ok {
return ErrBotAlreadyExists
if _, ok := m.mapping[b.Token]; ok {
return nil, ErrBotAlreadyExists
}
if _, ok := m.urlMapping[bData.urlPath]; bData.urlPath != "" && ok {
return ErrBotUrlPathAlreadyExists

if _, ok := m.urlMapping[urlPath]; urlPath != "" && ok {
return nil, ErrBotUrlPathAlreadyExists
}

bData := botData{
bot: b,
updateChan: make(chan json.RawMessage),
stopUpdates: make(chan struct{}),
updateWriterControl: &sync.WaitGroup{},
urlPath: urlPath,
webhookSecret: webhookSecret,
}

m.mapping[bData.bot.Token] = bData
m.urlMapping[bData.urlPath] = bData.bot.Token
return nil
return &bData, nil
}

func (m *botMapping) removeBot(token string) (botData, bool) {
Expand Down Expand Up @@ -143,6 +160,8 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
w.WriteHeader(http.StatusNotFound)
return
}
b.updateWriterControl.Add(1)
defer b.updateWriterControl.Done()

headerSecret := r.Header.Get("X-Telegram-Bot-Api-Secret-Token")
if b.webhookSecret != "" && b.webhookSecret != headerSecret {
Expand All @@ -161,6 +180,11 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
w.WriteHeader(http.StatusInternalServerError)
return
}

if b.isUpdateChannelStopped() {
return
}

b.updateChan <- bytes
}
}
Expand All @@ -172,3 +196,27 @@ func (m *botMapping) logf(format string, args ...interface{}) {
log.Printf(format, args...)
}
}

func (b *botData) stop() {
// Close stopUpdates loops first, to ensure any updates currently being polled have the time to be sent to the updateChan.
if b.stopUpdates != nil {
close(b.stopUpdates)
}

// Wait for all writers to finish writing to the updateChannel
b.updateWriterControl.Wait()

// Then, close the updates channel.
close(b.updateChan)
}

func (b *botData) isUpdateChannelStopped() bool {
select {
case <-b.stopUpdates:
// if anything comes in on the closing channel, we know the channel is closed.
return true
default:
// otherwise, continue as usual
return false
}
}
49 changes: 32 additions & 17 deletions ext/botmapping_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ext

import (
"encoding/json"
"testing"

"github.com/PaulSonOfLars/gotgbot/v2"
Expand All @@ -15,16 +14,11 @@ func Test_botMapping(t *testing.T) {
BotClient: &gotgbot.BaseBotClient{},
}

updateChan := make(chan json.RawMessage)
pollChan := make(chan struct{})

var origBdata *botData
t.Run("addBot", func(t *testing.T) {
// check that bots can be added fine
err := bm.addBot(botData{
bot: b,
updateChan: updateChan,
polling: pollChan,
})
var err error
origBdata, err = bm.addBot(b, "", "")
if err != nil {
t.Errorf("expected to be able to add a new bot fine: %s", err.Error())
t.FailNow()
Expand All @@ -37,11 +31,7 @@ func Test_botMapping(t *testing.T) {

t.Run("doubleAdd", func(t *testing.T) {
// Adding the same bot twice should fail
err := bm.addBot(botData{
bot: b,
updateChan: updateChan,
polling: pollChan,
})
_, err := bm.addBot(b, "", "")
if err == nil {
t.Errorf("adding the same bot twice should throw an error")
t.FailNow()
Expand All @@ -59,11 +49,11 @@ func Test_botMapping(t *testing.T) {
t.Errorf("failed to get bot with token %s", b.Token)
t.FailNow()
}
if bdata.polling != pollChan {
t.Errorf("polling channel was not the same")
if bdata.stopUpdates != origBdata.stopUpdates {
t.Errorf("stopUpdates channel was not the same")
t.FailNow()
}
if bdata.updateChan != updateChan {
if bdata.updateChan != origBdata.updateChan {
t.Errorf("update channel was not the same")
t.FailNow()
}
Expand All @@ -85,3 +75,28 @@ func Test_botMapping(t *testing.T) {
})

}

func Test_botData_isUpdateChannelStopped(t *testing.T) {
bm := botMapping{}
b := &gotgbot.Bot{
User: gotgbot.User{},
Token: "SOME_TOKEN",
BotClient: &gotgbot.BaseBotClient{},
}

bData, err := bm.addBot(b, "", "")
if err != nil {
t.Errorf("bot with token %s should not have failed to be added", b.Token)
return
}
if bData.isUpdateChannelStopped() {
t.Errorf("bot with token %s should not be stopped yet", b.Token)
return
}

bData.stop()
if !bData.isUpdateChannelStopped() {
t.Errorf("bot with token %s should be stopped", b.Token)
return
}
}
57 changes: 16 additions & 41 deletions ext/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,37 +157,25 @@ func (u *Updater) StartPolling(b *gotgbot.Bot, opts *PollingOpts) error {
}
}

updateChan := make(chan json.RawMessage)
pollChan := make(chan struct{})

err := u.botMapping.addBot(botData{
bot: b,
updateChan: updateChan,
polling: pollChan,
})
bData, err := u.botMapping.addBot(b, "", "")
if err != nil {
return fmt.Errorf("failed to add bot with long polling: %w", err)
}

go u.Dispatcher.Start(b, updateChan)
go u.pollingLoop(b, reqOpts, pollChan, updateChan, v)
go u.Dispatcher.Start(b, bData.updateChan)
go u.pollingLoop(bData, reqOpts, v)

return nil
}

func (u *Updater) pollingLoop(b *gotgbot.Bot, opts *gotgbot.RequestOpts, stopPolling <-chan struct{}, updateChan chan<- json.RawMessage, v map[string]string) {
for {
select {
case <-stopPolling:
// if anything comes in, stop polling.
return
default:
// otherwise, continue as usual
}
func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[string]string) {
bData.updateWriterControl.Add(1)
defer bData.updateWriterControl.Done()

for {
// Manually craft the getUpdate calls to improve memory management, reduce json parsing overheads, and
// unnecessary reallocation of url.Values in the polling loop.
r, err := b.Request("getUpdates", v, nil, opts)
r, err := bData.bot.Request("getUpdates", v, nil, opts)
if err != nil {
if u.UnhandledErrFunc != nil {
u.UnhandledErrFunc(err)
Expand Down Expand Up @@ -230,9 +218,14 @@ func (u *Updater) pollingLoop(b *gotgbot.Bot, opts *gotgbot.RequestOpts, stopPol
}

v["offset"] = strconv.FormatInt(lastUpdate.UpdateId+1, 10)

if bData.isUpdateChannelStopped() {
return
}

for _, updData := range rawUpdates {
temp := updData // use new mem address to avoid loop conflicts
updateChan <- temp
bData.updateChan <- temp
}
}
}
Expand Down Expand Up @@ -285,17 +278,6 @@ func (u *Updater) StopAllBots() {
}
}

func (data botData) stop() {
// Close polling loops first, to ensure any updates currently being polled have the time to be sent to the
// updateChan.
if data.polling != nil {
close(data.polling)
}

// Then, close the updates channel.
close(data.updateChan)
}

// StartWebhook starts the webhook server for a single bot instance.
// This does NOT set the webhook on telegram - this should be done by the caller.
// The opts parameter allows for specifying various webhook settings.
Expand All @@ -316,20 +298,13 @@ func (u *Updater) AddWebhook(b *gotgbot.Bot, urlPath string, opts WebhookOpts) e
return fmt.Errorf("expected a non-empty url path: %w", ErrEmptyPath)
}

updateChan := make(chan json.RawMessage)

err := u.botMapping.addBot(botData{
bot: b,
updateChan: updateChan,
urlPath: strings.TrimPrefix(urlPath, "/"),
webhookSecret: opts.SecretToken,
})
bData, err := u.botMapping.addBot(b, strings.TrimPrefix(urlPath, "/"), opts.SecretToken)
if err != nil {
return fmt.Errorf("failed to add webhook for bot: %w", err)
}

// Webhook has been added; relevant dispatcher should also be started.
go u.Dispatcher.Start(b, updateChan)
go u.Dispatcher.Start(b, bData.updateChan)
return nil
}

Expand Down
Loading
Loading