diff --git a/ext/updater.go b/ext/updater.go index 45c1cdd5..b028245d 100644 --- a/ext/updater.go +++ b/ext/updater.go @@ -300,10 +300,6 @@ func (data botData) stop() { // This does NOT set the webhook on telegram - this should be done by the caller. // The opts parameter allows for specifying various webhook settings. func (u *Updater) StartWebhook(b *gotgbot.Bot, urlPath string, opts WebhookOpts) error { - if u.webhookServer != nil { - return ErrExpectedEmptyServer - } - err := u.AddWebhook(b, urlPath, opts) if err != nil { return fmt.Errorf("failed to add webhook: %w", err) @@ -325,7 +321,7 @@ func (u *Updater) AddWebhook(b *gotgbot.Bot, urlPath string, opts WebhookOpts) e err := u.botMapping.addBot(botData{ bot: b, updateChan: updateChan, - urlPath: urlPath, + urlPath: strings.TrimPrefix(urlPath, "/"), webhookSecret: opts.SecretToken, }) if err != nil { @@ -375,6 +371,10 @@ func (u *Updater) StartServer(opts WebhookOpts) error { return fmt.Errorf("failed to listen on %s:%s: %w", opts.ListenNet, opts.ListenAddr, err) } + if u.webhookServer != nil { + return ErrExpectedEmptyServer + } + u.webhookServer = &http.Server{ Handler: u.GetHandlerFunc("/"), ReadTimeout: opts.ReadTimeout, diff --git a/ext/updater_test.go b/ext/updater_test.go index 72f68ceb..49e3957d 100644 --- a/ext/updater_test.go +++ b/ext/updater_test.go @@ -1,6 +1,7 @@ package ext_test import ( + "context" "encoding/json" "errors" "fmt" @@ -85,6 +86,134 @@ func TestUpdaterDisallowsEmptyWebhooks(t *testing.T) { } } +func TestUpdater_GetHandlerFunc(t *testing.T) { + b := &gotgbot.Bot{ + Token: "SOME_TOKEN", + BotClient: &gotgbot.BaseBotClient{}, + } + + type args struct { + urlPath string + opts ext.WebhookOpts + httpResponse int + handlerPrefix string + requestPath string // Should start with '/' + headers map[string]string + } + tests := []struct { + name string + args args + }{ + { + name: "simple path", + args: args{ + urlPath: "123:hello", + httpResponse: http.StatusOK, + handlerPrefix: "/", + requestPath: "/123:hello", + }, + }, { + name: "slash prefixed path", + args: args{ + urlPath: "/123:hello", + httpResponse: http.StatusOK, + handlerPrefix: "/", + requestPath: "/123:hello", + }, + }, { + name: "using subpath", + args: args{ + urlPath: "123:hello", + httpResponse: http.StatusOK, + handlerPrefix: "/test/", + requestPath: "/test/123:hello", + }, + }, { + name: "unknown path", + args: args{ + urlPath: "123:hello", + httpResponse: http.StatusNotFound, + handlerPrefix: "/", + requestPath: "/this-path-doesnt-exist", + }, + }, { + name: "missing secret token", + args: args{ + urlPath: "123:hello", + opts: ext.WebhookOpts{ + SecretToken: "secret", + }, + httpResponse: http.StatusUnauthorized, + handlerPrefix: "/", + requestPath: "/123:hello", + }, + }, { + name: "matching secret token", + args: args{ + urlPath: "123:hello", + opts: ext.WebhookOpts{ + SecretToken: "secret", + }, + httpResponse: http.StatusOK, + handlerPrefix: "/", + requestPath: "/123:hello", + headers: map[string]string{ + "X-Telegram-Bot-Api-Secret-Token": "secret", + }, + }, + }, { + name: "invalid secret token", + args: args{ + urlPath: "123:hello", + opts: ext.WebhookOpts{ + SecretToken: "secret", + }, + httpResponse: http.StatusUnauthorized, + handlerPrefix: "/", + requestPath: "/123:hello", + headers: map[string]string{ + "X-Telegram-Bot-Api-Secret-Token": "wrong", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := ext.NewDispatcher(nil) + u := ext.NewUpdater(d, nil) + + if err := u.AddWebhook(b, tt.args.urlPath, tt.args.opts); err != nil { + t.Errorf("failed to add webhook: %v", err) + return + } + + s := httptest.NewServer(u.GetHandlerFunc(tt.args.handlerPrefix)) + url := s.URL + tt.args.requestPath + // We pass {} to satisfy JSON handling + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, strings.NewReader("{}")) + if err != nil { + t.Errorf("Failed to build request, should have worked: %v", err.Error()) + return + } + + for k, v := range tt.args.headers { + req.Header.Set(k, v) + } + + r, err := s.Client().Do(req) + if err != nil { + t.Fatal() + } + + defer r.Body.Close() + if r.StatusCode != tt.args.httpResponse { + t.Errorf("Expected code %d, got %d", tt.args.httpResponse, r.StatusCode) + return + } + }) + } +} + func TestUpdaterAllowsWebhookDeletion(t *testing.T) { server := basicTestServer(t, map[string]string{ "getUpdates": `{}`, @@ -116,6 +245,12 @@ func TestUpdaterAllowsWebhookDeletion(t *testing.T) { t.Errorf("failed to start long poll on first bot: %v", err) return } + + err = u.Stop() + if err != nil { + t.Errorf("failed to stop updater: %v", err) + return + } } func TestUpdaterSupportsTwoPollingBots(t *testing.T) { @@ -165,6 +300,12 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) { t.Errorf("should be able to add two different polling bots") return } + + err = u.Stop() + if err != nil { + t.Errorf("failed to stop updater: %v", err) + return + } } func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) { @@ -207,6 +348,12 @@ func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) { t.Errorf("should have failed to start the same long poll twice, but didnt") return } + + err = u.Stop() + if err != nil { + t.Errorf("failed to stop updater: %v", err) + return + } } func TestUpdaterSupportsLongPollReAdding(t *testing.T) { @@ -252,6 +399,12 @@ func TestUpdaterSupportsLongPollReAdding(t *testing.T) { t.Errorf("Failed to re-start a previously removed bot: %v", err) return } + + err = u.Stop() + if err != nil { + t.Errorf("failed to stop updater: %v", err) + return + } } func basicTestServer(t *testing.T, methods map[string]string) *httptest.Server {