Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend ext.Context to store bot information #198

Merged
merged 4 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)

//go:generate go run ./scripts/generate

var (
ErrNilBotClient = errors.New("nil BotClient")
ErrInvalidTokenFormat = errors.New("invalid token format")
)

// Bot is the default Bot struct used to send and receive messages to the telegram API.
type Bot struct {
// Token stores the bot's secret token obtained from t.me/BotFather, and used to interact with telegram's API.
Expand Down Expand Up @@ -76,6 +83,24 @@ func NewBot(token string, opts *BotOpts) (*Bot, error) {
return nil, fmt.Errorf("failed to check bot token: %w", err)
}
b.User = *botUser
} else {
// If token checks are disabled, we populate the bot's ID from the token.
split := strings.Split(token, ":")
if len(split) != 2 {
return nil, fmt.Errorf("%w: expected '123:abcd', got %s", ErrInvalidTokenFormat, token)
}

id, err := strconv.ParseInt(split[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse bot ID from token: %w", err)
}
b.User = User{
Id: id,
IsBot: true,
// We mark these fields as missing so we can know why they're not available
FirstName: "<missing>",
Username: "<missing>",
}
}

return &b, nil
Expand All @@ -89,8 +114,6 @@ func (bot *Bot) UseMiddleware(mw func(client BotClient) BotClient) *Bot {
return bot
}

var ErrNilBotClient = errors.New("nil BotClient")

func (bot *Bot) Request(method string, params map[string]string, data map[string]FileReader, opts *RequestOpts) (json.RawMessage, error) {
return bot.RequestWithContext(context.Background(), method, params, data, opts)
}
Expand Down
8 changes: 6 additions & 2 deletions ext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
type Context struct {
// gotgbot.Update is inlined so that we can access all fields immediately if necessary.
*gotgbot.Update
// Bot represents gotgbot.User behind the Bot that received this update, so we can keep track of update ownership.
// Note: this information may be incomplete in the case where token validation is disabled.
Bot gotgbot.User
// Data represents update-local storage.
// This can be used to pass data across handlers - for example, to cache operations relevant to the current update,
// such as admin checks.
Expand All @@ -35,9 +38,9 @@ type Context struct {
EffectiveSender *gotgbot.Sender
}

// NewContext populates a context with the relevant fields from the current update.
// NewContext populates a context with the relevant fields from the current bot and update.
// It takes a data field in the case where custom data needs to be passed.
func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context {
func NewContext(b *gotgbot.Bot, update *gotgbot.Update, data map[string]interface{}) *Context {
var msg *gotgbot.Message
var chat *gotgbot.Chat
var user *gotgbot.User
Expand Down Expand Up @@ -162,6 +165,7 @@ func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context {

return &Context{
Update: update,
Bot: b.User,
Data: data,
EffectiveMessage: msg,
EffectiveChat: chat,
Expand Down
2 changes: 1 addition & 1 deletion ext/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (d *Dispatcher) processRawUpdate(b *gotgbot.Bot, r json.RawMessage) error {
// ProcessUpdate iterates over the list of groups to execute the matching handlers.
// This is also where we recover from any panics that are thrown by user code, to avoid taking down the bot.
func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[string]interface{}) (err error) {
ctx := NewContext(u, data)
ctx := NewContext(b, u, data)

defer func() {
if r := recover(); r != nil {
Expand Down
2 changes: 1 addition & 1 deletion ext/dispatcher_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestDispatcher(t *testing.T) {
}

t.Log("Processing one update...")
err := d.ProcessUpdate(nil, &gotgbot.Update{
err := d.ProcessUpdate(&gotgbot.Bot{}, &gotgbot.Update{
Message: &gotgbot.Message{Text: "test text"},
}, nil)
if err != nil {
Expand Down
14 changes: 7 additions & 7 deletions ext/handlers/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func NewTestBot() *gotgbot.Bot {
return &gotgbot.Bot{
Token: "use-me",
User: gotgbot.User{
Id: 0,
Id: rand.Int63(),
IsBot: false,
FirstName: "gobot",
LastName: "",
Expand All @@ -33,13 +33,13 @@ func NewTestBot() *gotgbot.Bot {
}
}

func NewMessage(userId int64, chatId int64, message string) *ext.Context {
return newMessage(userId, chatId, message, nil)
func NewMessage(b *gotgbot.Bot, userId int64, chatId int64, message string) *ext.Context {
return newMessage(b, userId, chatId, message, nil)
}

func NewCommandMessage(userId int64, chatId int64, command string, args []string) *ext.Context {
func NewCommandMessage(b *gotgbot.Bot, userId int64, chatId int64, command string, args []string) *ext.Context {
msg, ents := buildCommand(command, args)
return newMessage(userId, chatId, msg, ents)
return newMessage(b, userId, chatId, msg, ents)
}

func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) {
Expand All @@ -53,13 +53,13 @@ func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) {
}
}

func newMessage(userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context {
func newMessage(b *gotgbot.Bot, userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context {
chatType := "supergroup"
if userId == chatId {
chatType = "private"
}

return ext.NewContext(&gotgbot.Update{
return ext.NewContext(b, &gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Message: &gotgbot.Message{
MessageId: rand.Int63(), // should this be consistent?
Expand Down
7 changes: 3 additions & 4 deletions ext/handlers/conversation/key_strategies.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package conversation
import (
"errors"
"fmt"
"strconv"

"github.com/PaulSonOfLars/gotgbot/v2/ext"
)
Expand All @@ -27,23 +26,23 @@ func KeyStrategySenderAndChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil || ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing sender or chat fields: %w", ErrEmptyKey)
}
return fmt.Sprintf("%d/%d", ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
return fmt.Sprintf("%d/%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
}

// KeyStrategySender gives a unique conversation to each sender, and that single conversation is available in all chats.
func KeyStrategySender(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil {
return "", fmt.Errorf("missing sender field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveSender.Id(), 10), nil
return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id()), nil
}

// KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together.
func KeyStrategyChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing chat field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveChat.Id, 10), nil
return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveChat.Id), nil
}

// StateKey provides a sane default for handling incoming updates.
Expand Down
45 changes: 28 additions & 17 deletions ext/handlers/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ func TestBasicConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
textMessage := NewMessage(userId, chatId, "message")
textMessage := NewMessage(b, userId, chatId, "message")
runHandler(t, b, &conv, textMessage, nextStep, "")
if !ended {
t.Fatalf("expected the internal handler to have run")
Expand Down Expand Up @@ -79,8 +79,8 @@ func TestBasicKeyedConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startFromUserOne := NewCommandMessage(userIdOne, chatId, "start", []string{})
messageFromTwo := NewMessage(userIdTwo, chatId, "message")
startFromUserOne := NewCommandMessage(b, userIdOne, chatId, "start", []string{})
messageFromTwo := NewMessage(b, userIdTwo, chatId, "message")

runHandler(t, b, &conv, startFromUserOne, "", nextStep)

Expand All @@ -89,6 +89,11 @@ func TestBasicKeyedConversation(t *testing.T) {

// But user two doesnt exist
checkExpectedState(t, &conv, messageFromTwo, "")

b2 := NewTestBot()
messageTo2 := NewMessage(b2, userIdOne, chatId, "message")
// And bot two hasn't changed either
checkExpectedState(t, &conv, messageTo2, "")
}

func TestBasicConversationExit(t *testing.T) {
Expand Down Expand Up @@ -121,14 +126,14 @@ func TestBasicConversationExit(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint, and starting the conversation.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "cancel" command, triggering the exitpoint, and immediately ending the conversation.
cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{})
cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{})
runHandler(t, b, &conv, cancelCommand, nextStep, "")
if !ended {
t.Fatalf("expected the cancel command to have run")
Expand All @@ -138,7 +143,7 @@ func TestBasicConversationExit(t *testing.T) {
checkExpectedState(t, &conv, cancelCommand, "")

// Emulate sending the "message" text, which now should not interact with the conversation.
textMessage := NewMessage(userId, chatId, "message")
textMessage := NewMessage(b, userId, chatId, "message")
if conv.CheckUpdate(b, textMessage) {
t.Fatalf("did not expect the internal handler to run")
}
Expand Down Expand Up @@ -177,14 +182,14 @@ func TestFallbackConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "cancel" command, triggering the fallback handler (and causing it to "end").
cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{})
cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{})
runHandler(t, b, &conv, cancelCommand, nextStep, "")
if !fallback {
t.Fatalf("expected the fallback handler to have run")
Expand Down Expand Up @@ -220,14 +225,14 @@ func TestReEntryConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if startCount != 1 {
t.Fatalf("expected the entrypoint handler to have run")
}

// Send a message which matches both the entrypoint, and the "nextStep" state.
cancelCommand := NewCommandMessage(userId, chatId, "start", []string{"message"})
cancelCommand := NewCommandMessage(b, userId, chatId, "start", []string{"message"})
runHandler(t, b, &conv, cancelCommand, nextStep, nextStep) // Should hit
if startCount != 2 {
t.Fatalf("expected the entrypoint handler to have run a second time")
Expand Down Expand Up @@ -285,20 +290,20 @@ func TestNestedConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
start := NewCommandMessage(userId, chatId, startCmd, []string{})
start := NewCommandMessage(b, userId, chatId, startCmd, []string{})
runHandler(t, b, &conv, start, "", firstStep)

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
textMessage := NewMessage(userId, chatId, messageText)
textMessage := NewMessage(b, userId, chatId, messageText)
runHandler(t, b, &conv, textMessage, firstStep, secondStep)

// Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation.
nestedStart := NewCommandMessage(userId, chatId, nestedStartCmd, []string{})
nestedStart := NewCommandMessage(b, userId, chatId, nestedStartCmd, []string{})
willRunHandler(t, b, &nestedConv, nestedStart, "")
runHandler(t, b, &conv, nestedStart, secondStep, secondStep)

// Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation.
nestedFinish := NewMessage(userId, chatId, finishNestedText)
nestedFinish := NewMessage(b, userId, chatId, finishNestedText)
willRunHandler(t, b, &nestedConv, nestedFinish, nestedStep)
runHandler(t, b, &conv, nestedFinish, secondStep, thirdStep)

Expand All @@ -307,7 +312,7 @@ func TestNestedConversation(t *testing.T) {
t.Log("Nested conversation finished")

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
finish := NewMessage(userId, chatId, finishText)
finish := NewMessage(b, userId, chatId, finishText)
runHandler(t, b, &conv, finish, thirdStep, "")

checkExpectedState(t, &conv, textMessage, "")
Expand All @@ -329,7 +334,7 @@ func TestEmptyKeyConversation(t *testing.T) {
)

// Run an empty
pollUpd := ext.NewContext(&gotgbot.Update{
pollUpd := ext.NewContext(b, &gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Poll: &gotgbot.Poll{
Id: "some_id",
Expand Down Expand Up @@ -358,6 +363,8 @@ func TestEmptyKeyConversation(t *testing.T) {

// runHandler ensures that the incoming update will trigger the conversation.
func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, currentState string, nextState string) {
t.Helper()

willRunHandler(t, b, conv, message, currentState)
if err := conv.HandleUpdate(b, message); err != nil {
t.Fatalf("unexpected error from handler: %s", err.Error())
Expand All @@ -368,6 +375,8 @@ func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, messa

// willRunHandler ensures that the incoming update will trigger the conversation.
func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, expectedState string) {
t.Helper()

t.Logf("conv %p: checking message for %d in %d with text: %s", conv, message.EffectiveSender.Id(), message.EffectiveChat.Id, message.Message.Text)

checkExpectedState(t, conv, message, expectedState)
Expand All @@ -378,6 +387,8 @@ func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, m
}

func checkExpectedState(t *testing.T, conv *handlers.Conversation, message *ext.Context, nextState string) {
t.Helper()

currentState, err := conv.StateStorage.Get(message)
if err != nil {
if nextState == "" && errors.Is(err, conversation.ErrKeyNotFound) {
Expand Down