Skip to content

Commit

Permalink
Improve ConversationHandler logic, and expose custom KeyStrategy func…
Browse files Browse the repository at this point in the history
…tions (#166)

* Allow for defining custom key state functions for improved conversation handling

* Allow for defining custom key state functions for improved conversation handling

* add comment

* nextHandler should not ignore emptyKey errors

* add tests to cover the case of any updates which might break a conversation

* Add a conversation-wide filter to control which updates get processed

* Add a conversation handler filter
  • Loading branch information
PaulSonOfLars authored Jul 1, 2024
1 parent a8e2c08 commit 43b5186
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 44 deletions.
21 changes: 20 additions & 1 deletion ext/handlers/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
// TODO: Add a "block" option to force linear processing. Also a "waiting" state to handle blocked handlers.
// TODO: Allow for timeouts (and a "timeout" state to handle that)

// ConversationFilter is much wider than regular filters, because it allows for any kind of update; we may want
// messages, commands, callbacks, etc.
type ConversationFilter func(ctx *ext.Context) bool

// The Conversation handler is an advanced handler which allows for running a sequence of commands in a stateful manner.
// An example of this flow can be found at t.me/Botfather; upon receiving the "/newbot" command, the user is asked for
// the name of their bot, which is sent as a separate message.
Expand All @@ -33,6 +37,10 @@ type Conversation struct {
Fallbacks []ext.Handler
// If True, a user can restart the conversation by hitting one of the entry points.
AllowReEntry bool
// Filter allows users to set a conversation-wide filter to any incoming updates. This can be useful to only target
// one specific chat, or to avoid unwanted updates which may interfere with the conversation key strategy
// (eg polls).
Filter ConversationFilter
}

type ConversationOpts struct {
Expand All @@ -45,6 +53,10 @@ type ConversationOpts struct {
AllowReEntry bool
// StateStorage is responsible for storing all running conversations.
StateStorage conversation.Storage
// Filter allows users to set a conversation-wide filter to any incoming updates. This can be useful to only target
// one specific chat, or to avoid unwanted updates which may interfere with the conversation key strategy
// (eg polls).
Filter ConversationFilter
}

func NewConversation(entryPoints []ext.Handler, states map[string][]ext.Handler, opts *ConversationOpts) Conversation {
Expand All @@ -59,6 +71,7 @@ func NewConversation(entryPoints []ext.Handler, states map[string][]ext.Handler,
c.Exits = opts.Exits
c.Fallbacks = opts.Fallbacks
c.AllowReEntry = opts.AllowReEntry
c.Filter = opts.Filter

// If no StateStorage is specified, we should keep the default.
if opts.StateStorage != nil {
Expand Down Expand Up @@ -169,10 +182,16 @@ func (c Conversation) Name() string {
// getNextHandler goes through all the handlers in the conversation, until it finds a handler that matches.
// If no matching handler is found, returns nil.
func (c Conversation) getNextHandler(b *gotgbot.Bot, ctx *ext.Context) (ext.Handler, error) {
// If the user has defined a filter, and this filter does NOT return true, then we do NOT want to consider this
// update for the conversation.
if c.Filter != nil && !c.Filter(ctx) {
return nil, nil
}

// Check if a conversation has already started for this user.
currState, err := c.StateStorage.Get(ctx)
if err != nil {
if errors.Is(err, conversation.KeyNotFound) {
if errors.Is(err, conversation.ErrKeyNotFound) {
// If this is an unknown conversation key, then we know this is a new conversation, so we check all
// entrypoints.
return checkHandlerList(c.EntryPoints, b, ctx), nil
Expand Down
22 changes: 0 additions & 22 deletions ext/handlers/conversation/common.go

This file was deleted.

21 changes: 15 additions & 6 deletions ext/handlers/conversation/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/PaulSonOfLars/gotgbot/v2/ext"
)

var KeyNotFound = errors.New("conversation key not found")
var ErrKeyNotFound = errors.New("conversation key not found")

// InMemoryStorage is a thread-safe in-memory implementation of the Storage interface.
type InMemoryStorage struct {
Expand All @@ -28,24 +28,30 @@ func NewInMemoryStorage(strategy KeyStrategy) *InMemoryStorage {
}

func (c *InMemoryStorage) Get(ctx *ext.Context) (*State, error) {
key := StateKey(ctx, c.keyStrategy)
key, err := StateKey(ctx, c.keyStrategy)
if err != nil {
return nil, err
}

c.lock.RLock()
defer c.lock.RUnlock()

if c.conversations == nil {
return nil, KeyNotFound
return nil, ErrKeyNotFound
}

s, ok := c.conversations[key]
if !ok {
return nil, KeyNotFound
return nil, ErrKeyNotFound
}
return &s, nil
}

func (c *InMemoryStorage) Set(ctx *ext.Context, state State) error {
key := StateKey(ctx, c.keyStrategy)
key, err := StateKey(ctx, c.keyStrategy)
if err != nil {
return err
}

c.lock.Lock()
defer c.lock.Unlock()
Expand All @@ -59,7 +65,10 @@ func (c *InMemoryStorage) Set(ctx *ext.Context, state State) error {
}

func (c *InMemoryStorage) Delete(ctx *ext.Context) error {
key := StateKey(ctx, c.keyStrategy)
key, err := StateKey(ctx, c.keyStrategy)
if err != nil {
return err
}

c.lock.Lock()
defer c.lock.Unlock()
Expand Down
62 changes: 52 additions & 10 deletions ext/handlers/conversation/key_strategies.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,55 @@
package conversation

type KeyStrategy int64

// Note: If you add a new keystrategy here, make sure to add it to the getStateKey method!
const (
// KeyStrategySenderAndChat ensures that each sender get a unique conversation in each chats.
KeyStrategySenderAndChat KeyStrategy = iota
// KeyStrategySender gives a unique conversation to each sender, but that conversation is available in all chats.
KeyStrategySender
// KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together.
KeyStrategyChat
import (
"errors"
"fmt"
"strconv"

"github.com/PaulSonOfLars/gotgbot/v2/ext"
)

var ErrEmptyKey = errors.New("empty conversation key")

// KeyStrategy is the function used to obtain the current key in the ongoing conversation.
//
// Use one of the existing keys, or define your own if you need external data (eg a DB or other state).
type KeyStrategy func(ctx *ext.Context) (string, error)

var (
// Ensure key strategy methods match the function signatures.
_ KeyStrategy = KeyStrategyChat
_ KeyStrategy = KeyStrategySender
_ KeyStrategy = KeyStrategySenderAndChat
)

// KeyStrategySenderAndChat ensures that each sender get a unique conversation, even in different chats.
func KeyStrategySenderAndChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil || ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing sender or chat fields: %w", ErrEmptyKey)
}
return fmt.Sprintf("%d/%d", ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
}

// KeyStrategySender gives a unique conversation to each sender, and that single conversation is available in all chats.
func KeyStrategySender(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil {
return "", fmt.Errorf("missing sender field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveSender.Id(), 10), nil
}

// KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together.
func KeyStrategyChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing chat field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveChat.Id, 10), nil
}

// StateKey provides a sane default for handling incoming updates.
func StateKey(ctx *ext.Context, strategy KeyStrategy) (string, error) {
if strategy == nil {
return KeyStrategySenderAndChat(ctx)
}
return strategy(ctx)
}
54 changes: 49 additions & 5 deletions ext/handlers/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers_test

import (
"errors"
"math/rand"
"testing"

"github.com/PaulSonOfLars/gotgbot/v2"
Expand Down Expand Up @@ -312,6 +313,49 @@ func TestNestedConversation(t *testing.T) {
checkExpectedState(t, &conv, textMessage, "")
}

func TestEmptyKeyConversation(t *testing.T) {
b := NewTestBot()

// Dummy conversation; not important.
conv := handlers.NewConversation(
[]ext.Handler{handlers.NewCommand("start", func(b *gotgbot.Bot, ctx *ext.Context) error {
return handlers.NextConversationState("next")
})},
map[string][]ext.Handler{},
&handlers.ConversationOpts{
// This strategy will fail when we don't have a chat/user; eg, a poll update, which has neither.
StateStorage: conversation.NewInMemoryStorage(conversation.KeyStrategySenderAndChat),
},
)

// Run an empty
pollUpd := ext.NewContext(&gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Poll: &gotgbot.Poll{
Id: "some_id",
Question: "Some question",
Type: "quiz",
AllowsMultipleAnswers: false,
CorrectOptionId: 0,
Explanation: "",
},
}, nil)

if err := conv.HandleUpdate(b, pollUpd); !errors.Is(err, conversation.ErrEmptyKey) {
t.Fatal("poll update should have caused an error in the conversation handler")
}

conv.Filter = func(ctx *ext.Context) bool {
// These are prerequisites for the SenderAndChat strategy; if we dont have them, skip!
return ctx.EffectiveChat != nil && ctx.EffectiveSender != nil
}

if err := conv.HandleUpdate(b, pollUpd); err != nil {
t.Fatal("poll update should NOT have caused an error, as it is now filtered out")
}

}

// runHandler ensures that the incoming update will trigger the conversation.
func runHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, message *ext.Context, currentState string, nextState string) {
willRunHandler(t, b, conv, message, currentState)
Expand All @@ -335,12 +379,12 @@ func willRunHandler(t *testing.T, b *gotgbot.Bot, conv *handlers.Conversation, m

func checkExpectedState(t *testing.T, conv *handlers.Conversation, message *ext.Context, nextState string) {
currentState, err := conv.StateStorage.Get(message)
if nextState == "" {
if !errors.Is(err, conversation.KeyNotFound) {
t.Fatalf("expected not to have a conversation, but got currentState: %s", currentState)
if err != nil {
if nextState == "" && errors.Is(err, conversation.ErrKeyNotFound) {
// Success! No next state, because we don't have a "next" key.
return
}
} else if err != nil {
t.Fatalf("unexpected error while checking the current currentState of the conversation")
t.Fatalf("unexpected error while checking the current currentState of the conversation: %s", err.Error())
} else if currentState == nil || currentState.Key != nextState {
t.Fatalf("expected the conversation to be at '%s', was '%s'", nextState, currentState)
}
Expand Down

0 comments on commit 43b5186

Please sign in to comment.