Skip to content

Commit

Permalink
Add a NamedHandler, as well as the ability to remove handlers by thei…
Browse files Browse the repository at this point in the history
…r name. (#47)

* Add handler removals

* Add some dispatcher tests to confirm removals work fine

* fix lint

* Add warnings that methods are not thread safe

* Fix lints

* Add threadsafe mutexes to protect against undefined behaviour from concurrent modification

* Improve tests, and fix the error that they uncovered

* pre-allocate handlers when using getGroups

* Add additional tests for handler mappings

* remove invalid thread safety comments on group add/remove

* remove redundant dummy redeclaration
  • Loading branch information
PaulSonOfLars authored Sep 2, 2023
1 parent eab7bce commit f87f4f7
Show file tree
Hide file tree
Showing 7 changed files with 459 additions and 16 deletions.
3 changes: 3 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ issues:
- gosec
text: "G404:" # warning about insecure math/rand. We dont care about this in tests!
path: "\\w*_test.go"
- linters:
- gosimple
text: "S1023:" # allow redundant return statements. They can be nice for readability.

# Enable default excludes, for common sense values.
exclude-use-default: true
Expand Down
3 changes: 2 additions & 1 deletion ext/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

type DummyHandler struct {
F func(b *gotgbot.Bot, ctx *Context) error
N string
}

func (d DummyHandler) CheckUpdate(b *gotgbot.Bot, ctx *Context) bool {
Expand All @@ -17,7 +18,7 @@ func (d DummyHandler) HandleUpdate(b *gotgbot.Bot, ctx *Context) error {
}

func (d DummyHandler) Name() string {
return "dummy"
return "dummy" + d.N
}

func (u *Updater) InjectUpdate(token string, upd gotgbot.Update) error {
Expand Down
35 changes: 20 additions & 15 deletions ext/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"log"
"runtime/debug"
"sort"
"strings"
"sync"

Expand Down Expand Up @@ -71,10 +70,8 @@ type Dispatcher struct {
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger

// handlerGroups represents the list of available handler groups, numerically sorted.
handlerGroups []int
// handlers represents all available handles, split into groups (see handlerGroups).
handlers map[int][]Handler
// handlers represents all available handlers.
handlers handlerMappings

// limiter is how we limit the maximum number of goroutines for handling updates.
// if nil, this is a limitless dispatcher.
Expand Down Expand Up @@ -152,7 +149,7 @@ func NewDispatcher(opts *DispatcherOpts) *Dispatcher {
Panic: panicHandler,
UnhandledErrFunc: unhandledErrFunc,
ErrorLog: errLog,
handlers: make(map[int][]Handler),
handlers: handlerMappings{},
limiter: limiter,
waitGroup: sync.WaitGroup{},
}
Expand Down Expand Up @@ -228,13 +225,21 @@ func (d *Dispatcher) AddHandler(handler Handler) {
}

// AddHandlerToGroup adds a handler to a specific group; lowest number will be processed first.
func (d *Dispatcher) AddHandlerToGroup(handler Handler, group int) {
currHandlers, ok := d.handlers[group]
if !ok {
d.handlerGroups = append(d.handlerGroups, group)
sort.Ints(d.handlerGroups)
}
d.handlers[group] = append(currHandlers, handler)
func (d *Dispatcher) AddHandlerToGroup(h Handler, group int) {
d.handlers.add(h, group)
}

// RemoveHandlerFromGroup removes a handler by name from the specified group.
// If multiple handlers have the same name, only the first one is removed.
// Returns true if the handler was successfully removed.
func (d *Dispatcher) RemoveHandlerFromGroup(handlerName string, group int) bool {
return d.handlers.remove(handlerName, group)
}

// RemoveGroup removes an entire group from the dispatcher's processing.
// If group can't be found, this is a noop.
func (d *Dispatcher) RemoveGroup(group int) bool {
return d.handlers.removeGroup(group)
}

// processRawUpdate takes a JSON update to be unmarshalled and processed by Dispatcher.ProcessUpdate.
Expand Down Expand Up @@ -274,8 +279,8 @@ func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[s
}

func (d *Dispatcher) iterateOverHandlerGroups(b *gotgbot.Bot, ctx *Context) error {
for _, groupNum := range d.handlerGroups {
for _, handler := range d.handlers[groupNum] {
for _, groups := range d.handlers.getGroups() {
for _, handler := range groups {
if !handler.CheckUpdate(b, ctx) {
// Handler filter doesn't match this update; continue.
continue
Expand Down
182 changes: 182 additions & 0 deletions ext/dispatcher_ext_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package ext_test

import (
"sort"
"testing"

"github.com/PaulSonOfLars/gotgbot/v2"
"github.com/PaulSonOfLars/gotgbot/v2/ext"
"github.com/PaulSonOfLars/gotgbot/v2/ext/handlers"
"github.com/PaulSonOfLars/gotgbot/v2/ext/handlers/filters/message"
)

func TestDispatcher(t *testing.T) {
type testHandler struct {
group int
shouldRun bool
returnVal error
}

for name, testParams := range map[string]struct {
handlers []testHandler
numMatches int
}{
"one group two handlers": {
handlers: []testHandler{
{
group: 0,
shouldRun: true,
returnVal: nil,
}, {
group: 0,
shouldRun: false, // same group, so doesnt run
returnVal: nil,
},
},
numMatches: 1,
},
"two handlers two groups": {
handlers: []testHandler{
{
group: 0,
shouldRun: true,
returnVal: nil,
}, {
group: 1,
shouldRun: true, // second group, so also runs
returnVal: nil,
},
},
numMatches: 2,
},
"end groups": {
handlers: []testHandler{
{
group: 0,
shouldRun: true,
returnVal: ext.EndGroups,
}, {
group: 1,
shouldRun: false, // ended, so second group doesnt run
returnVal: nil,
},
},
numMatches: 1,
},
"continue groups": {
handlers: []testHandler{
{
group: 0,
shouldRun: true,
returnVal: ext.ContinueGroups,
}, {
group: 0,
shouldRun: true, // continued, so second item in same group runs
returnVal: nil,
},
},
numMatches: 2,
},
} {
name, testParams := name, testParams

t.Run(name, func(t *testing.T) {
d := ext.NewDispatcher(nil)
var events []int
for idx, h := range testParams.handlers {
idx, h := idx, h

t.Logf("Loading handler %d in group %d", idx, h.group)
d.AddHandlerToGroup(handlers.NewMessage(message.All, func(b *gotgbot.Bot, ctx *ext.Context) error {
if !h.shouldRun {
t.Errorf("handler %d in group %d should not have run", idx, h.group)
t.FailNow()
}

t.Logf("handler %d in group %d has run, as expected", idx, h.group)
events = append(events, idx)
return h.returnVal
}), h.group)
}

t.Log("Processing one update...")
err := d.ProcessUpdate(nil, &gotgbot.Update{
Message: &gotgbot.Message{Text: "test text"},
}, nil)
if err != nil {
t.Errorf("Unexpected error while processing updates: %s", err.Error())
}

// ensure events handled in order
if !sort.IntsAreSorted(events) {
t.Errorf("order of events is not sorted: %v", events)
}
if len(events) != testParams.numMatches {
t.Errorf("got %d matches, expected %d ", len(events), testParams.numMatches)
}
})
}
}

func TestDispatcher_RemoveHandlerFromGroup(t *testing.T) {
d := ext.NewDispatcher(nil)

const removeMe = "remove_me"
const group = 0

d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group)

if found := d.RemoveHandlerFromGroup(removeMe, group); !found {
t.Errorf("RemoveHandlerFromGroup() = %v, want true", found)
}
}

func TestDispatcher_RemoveOneHandlerFromGroup(t *testing.T) {
d := ext.NewDispatcher(nil)

const removeMe = "remove_me"
const group = 0

// Load handler twice.
d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group)
d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group)

// Remove handler twice.
if found := d.RemoveHandlerFromGroup(removeMe, group); !found {
t.Errorf("RemoveHandlerFromGroup() = %v, want true", found)
}
if found := d.RemoveHandlerFromGroup(removeMe, group); !found {
t.Errorf("RemoveHandlerFromGroup() = %v, want true", found)
}
// fail! only 2 in there.
if found := d.RemoveHandlerFromGroup(removeMe, group); found {
t.Errorf("RemoveHandlerFromGroup() = %v, want false", found)
}
}

func TestDispatcher_RemoveHandlerNonExistingHandlerFromGroup(t *testing.T) {
d := ext.NewDispatcher(nil)

const keepMe = "keep_me"
const removeMe = "remove_me"
const group = 0

d.AddHandlerToGroup(handlers.NewNamedhandler(keepMe, handlers.NewMessage(message.All, nil)), group)

if found := d.RemoveHandlerFromGroup(removeMe, group); found {
t.Errorf("RemoveHandlerFromGroup() = %v, want false", found)
}
}

func TestDispatcher_RemoveHandlerHandlerFromNonExistingGroup(t *testing.T) {
d := ext.NewDispatcher(nil)

const removeMe = "remove_me"
const group = 0
const wrongGroup = 1
d.AddHandlerToGroup(handlers.NewNamedhandler(removeMe, handlers.NewMessage(message.All, nil)), group)

if found := d.RemoveHandlerFromGroup(removeMe, wrongGroup); found {
t.Errorf("RemoveHandlerFromGroup() = %v, want false", found)
}
}
112 changes: 112 additions & 0 deletions ext/handler_mapping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package ext

import (
"sort"
"sync"
)

type handlerMappings struct {
// mutex is used to ensure everything threadsafe.
mutex sync.RWMutex

// handlerGroups represents the list of available handler groups, numerically sorted.
handlerGroups []int
// handlers represents all available handlers, split into groups (see handlerGroups).
handlers map[int][]Handler
}

func (m *handlerMappings) add(h Handler, group int) {
m.mutex.Lock()
defer m.mutex.Unlock()

if m.handlers == nil {
m.handlers = map[int][]Handler{}
}
currHandlers, ok := m.handlers[group]
if !ok {
m.handlerGroups = append(m.handlerGroups, group)
sort.Ints(m.handlerGroups)
}
m.handlers[group] = append(currHandlers, h)
}

func (m *handlerMappings) remove(name string, group int) bool {
m.mutex.Lock()
defer m.mutex.Unlock()

currHandlers, ok := m.handlers[group]
if !ok {
// group does not exist; removal failed.
return false
}

for idx, handler := range currHandlers {
if handler.Name() != name {
continue
}

// Only one item left, so just delete the group entirely.
if len(currHandlers) == 1 {
// get index of the current group to remove it from the list of handlergroups
gIdx := getIndex(group, m.handlerGroups)
if gIdx != -1 {
m.handlerGroups = append(m.handlerGroups[:gIdx], m.handlerGroups[gIdx+1:]...)
}
delete(m.handlers, group)
return true
}

// Make sure to copy the handler list to ensure we don't change the values of the underlying arrays, which
// could cause slice access issues when used concurrently.
newHandlers := make([]Handler, len(m.handlers[group]))
copy(newHandlers, m.handlers[group])

m.handlers[group] = append(newHandlers[:idx], newHandlers[idx+1:]...)
return true
}
// handler not found - removal failed.
return false
}

func getIndex(find int, is []int) int {
for i, v := range is {
if v == find {
return i
}
}
return -1
}

func (m *handlerMappings) removeGroup(group int) bool {
m.mutex.Lock()
defer m.mutex.Unlock()

if _, ok := m.handlers[group]; !ok {
// Group doesn't exist in map, so already removed.
return false
}

for idx, handlerGroup := range m.handlerGroups {
if handlerGroup != group {
continue
}

m.handlerGroups = append(m.handlerGroups[:idx], m.handlerGroups[idx+1:]...)
delete(m.handlers, group)
// Group found, and deleted. Success!
return true
}
// Group not found in list - so already removed.
return false
}

func (m *handlerMappings) getGroups() [][]Handler {
m.mutex.RLock()
defer m.mutex.RUnlock()

allHandlers := make([][]Handler, len(m.handlerGroups))
for idx, num := range m.handlerGroups {
allHandlers[idx] = m.handlers[num]
}
return allHandlers
}
Loading

0 comments on commit f87f4f7

Please sign in to comment.