Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix webhook handler prefix issue #126

Merged
merged 5 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ext/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,6 @@ func (data botData) stop() {
// This does NOT set the webhook on telegram - this should be done by the caller.
// The opts parameter allows for specifying various webhook settings.
func (u *Updater) StartWebhook(b *gotgbot.Bot, urlPath string, opts WebhookOpts) error {
if u.webhookServer != nil {
return ErrExpectedEmptyServer
}

err := u.AddWebhook(b, urlPath, opts)
if err != nil {
return fmt.Errorf("failed to add webhook: %w", err)
Expand All @@ -325,7 +321,7 @@ func (u *Updater) AddWebhook(b *gotgbot.Bot, urlPath string, opts WebhookOpts) e
err := u.botMapping.addBot(botData{
bot: b,
updateChan: updateChan,
urlPath: urlPath,
urlPath: strings.TrimPrefix(urlPath, "/"),
webhookSecret: opts.SecretToken,
})
if err != nil {
Expand Down Expand Up @@ -375,6 +371,10 @@ func (u *Updater) StartServer(opts WebhookOpts) error {
return fmt.Errorf("failed to listen on %s:%s: %w", opts.ListenNet, opts.ListenAddr, err)
}

if u.webhookServer != nil {
return ErrExpectedEmptyServer
}

u.webhookServer = &http.Server{
Handler: u.GetHandlerFunc("/"),
ReadTimeout: opts.ReadTimeout,
Expand Down
153 changes: 153 additions & 0 deletions ext/updater_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ext_test

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -85,6 +86,134 @@ func TestUpdaterDisallowsEmptyWebhooks(t *testing.T) {
}
}

func TestUpdater_GetHandlerFunc(t *testing.T) {
b := &gotgbot.Bot{
Token: "SOME_TOKEN",
BotClient: &gotgbot.BaseBotClient{},
}

type args struct {
urlPath string
opts ext.WebhookOpts
httpResponse int
handlerPrefix string
requestPath string // Should start with '/'
headers map[string]string
}
tests := []struct {
name string
args args
}{
{
name: "simple path",
args: args{
urlPath: "123:hello",
httpResponse: http.StatusOK,
handlerPrefix: "/",
requestPath: "/123:hello",
},
}, {
name: "slash prefixed path",
args: args{
urlPath: "/123:hello",
httpResponse: http.StatusOK,
handlerPrefix: "/",
requestPath: "/123:hello",
},
}, {
name: "using subpath",
args: args{
urlPath: "123:hello",
httpResponse: http.StatusOK,
handlerPrefix: "/test/",
requestPath: "/test/123:hello",
},
}, {
name: "unknown path",
args: args{
urlPath: "123:hello",
httpResponse: http.StatusNotFound,
handlerPrefix: "/",
requestPath: "/this-path-doesnt-exist",
},
}, {
name: "missing secret token",
args: args{
urlPath: "123:hello",
opts: ext.WebhookOpts{
SecretToken: "secret",
},
httpResponse: http.StatusUnauthorized,
handlerPrefix: "/",
requestPath: "/123:hello",
},
}, {
name: "matching secret token",
args: args{
urlPath: "123:hello",
opts: ext.WebhookOpts{
SecretToken: "secret",
},
httpResponse: http.StatusOK,
handlerPrefix: "/",
requestPath: "/123:hello",
headers: map[string]string{
"X-Telegram-Bot-Api-Secret-Token": "secret",
},
},
}, {
name: "invalid secret token",
args: args{
urlPath: "123:hello",
opts: ext.WebhookOpts{
SecretToken: "secret",
},
httpResponse: http.StatusUnauthorized,
handlerPrefix: "/",
requestPath: "/123:hello",
headers: map[string]string{
"X-Telegram-Bot-Api-Secret-Token": "wrong",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := ext.NewDispatcher(nil)
u := ext.NewUpdater(d, nil)

if err := u.AddWebhook(b, tt.args.urlPath, tt.args.opts); err != nil {
t.Errorf("failed to add webhook: %v", err)
return
}

s := httptest.NewServer(u.GetHandlerFunc(tt.args.handlerPrefix))
url := s.URL + tt.args.requestPath
// We pass {} to satisfy JSON handling
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, strings.NewReader("{}"))
if err != nil {
t.Errorf("Failed to build request, should have worked: %v", err.Error())
return
}

for k, v := range tt.args.headers {
req.Header.Set(k, v)
}

r, err := s.Client().Do(req)
if err != nil {
t.Fatal()
}

defer r.Body.Close()
if r.StatusCode != tt.args.httpResponse {
t.Errorf("Expected code %d, got %d", tt.args.httpResponse, r.StatusCode)
return
}
})
}
}

func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
server := basicTestServer(t, map[string]string{
"getUpdates": `{}`,
Expand Down Expand Up @@ -116,6 +245,12 @@ func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
t.Errorf("failed to start long poll on first bot: %v", err)
return
}

err = u.Stop()
if err != nil {
t.Errorf("failed to stop updater: %v", err)
return
}
}

func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
Expand Down Expand Up @@ -165,6 +300,12 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
t.Errorf("should be able to add two different polling bots")
return
}

err = u.Stop()
if err != nil {
t.Errorf("failed to stop updater: %v", err)
return
}
}

func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
Expand Down Expand Up @@ -207,6 +348,12 @@ func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
t.Errorf("should have failed to start the same long poll twice, but didnt")
return
}

err = u.Stop()
if err != nil {
t.Errorf("failed to stop updater: %v", err)
return
}
}

func TestUpdaterSupportsLongPollReAdding(t *testing.T) {
Expand Down Expand Up @@ -252,6 +399,12 @@ func TestUpdaterSupportsLongPollReAdding(t *testing.T) {
t.Errorf("Failed to re-start a previously removed bot: %v", err)
return
}

err = u.Stop()
if err != nil {
t.Errorf("failed to stop updater: %v", err)
return
}
}

func basicTestServer(t *testing.T, methods map[string]string) *httptest.Server {
Expand Down
Loading