Skip to content

Commit

Permalink
Allow for NSID prefix specification and put reasonable limits on it
Browse files Browse the repository at this point in the history
  • Loading branch information
ericvolp12 committed Sep 20, 2024
1 parent be2fbfb commit d118127
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 33 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ Connect to `ws://localhost:6008/subscribe` to start the stream
The following Query Parameters are supported:

- `wantedCollections` - An array of [Collection NSIDs](https://atproto.com/specs/nsid) to filter which records you receive on your stream (default empty = all collections)
- Regardless of desired collections, all subscribers receive Account and Identity events
- `wantedCollections` supports NSID path prefixes, e.g. `app.bsky.graph.*` or `app.bsky.*`. The prefix before the `.*` must pass NSID validation, and Jetstream **does not** support incomplete prefixes, e.g. `app.bsky.graph.fo*`.
- Regardless of desired collections, all subscribers receive Account and Identity events.
- You can specify at most 100 wanted collections/prefixes.
- `wantedDids` - An array of Repo DIDs to filter which records you receive on your stream (Default empty = all repos)
- You can specify at most 10,000 wanted DIDs.
- `cursor` - A unix microseconds timestamp cursor to begin playback from
- An absent cursor or a cursor from the future will result in live-tail operation
- When reconnecting, use the `time_us` from your most recently processed event and maybe provide a negative buffer (i.e. subtract a few seconds) to ensure gapless playback
Expand Down
135 changes: 103 additions & 32 deletions cmd/jetstream/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import (
"context"
"fmt"
"log/slog"
"net/http"
"slices"
"strconv"
"strings"
"sync"
"time"

Expand All @@ -23,17 +26,23 @@ var (
upgrader = websocket.Upgrader{}
)

// WantedCollections is a subscriber's collection filter, split into prefix
// matches and exact-name matches. Subscriber.WantsCollection accepts a
// collection if it matches either part.
type WantedCollections struct {
// Prefixes holds the collection-name prefixes; a collection matches when
// strings.HasPrefix(collection, prefix) is true for any entry.
Prefixes []string
// FullPaths is the set of exact collection names that match.
FullPaths map[string]struct{}
}

type Subscriber struct {
ws *websocket.Conn
realIP string
seq int64
buf chan *[]byte
id int64
cLk sync.Mutex
cursor *int64
deliveredCounter prometheus.Counter
bytesCounter prometheus.Counter
wantedCollections map[string]struct{}
ws *websocket.Conn
realIP string
seq int64
buf chan *[]byte
id int64
cLk sync.Mutex
cursor *int64
deliveredCounter prometheus.Counter
bytesCounter prometheus.Counter
// wantedCollections is nil if the subscriber wants all collections
wantedCollections *WantedCollections
wantedDids map[string]struct{}
rl *rate.Limiter
}
Expand Down Expand Up @@ -116,10 +125,8 @@ func (s *Server) Emit(ctx context.Context, e models.Event) error {
}

func emitToSubscriber(ctx context.Context, log *slog.Logger, sub *Subscriber, timeUS int64, did, collection string, playback bool, getEncodedEvent func() []byte) error {
if len(sub.wantedCollections) > 0 && collection != "" {
if _, ok := sub.wantedCollections[collection]; !ok {
return nil
}
if !sub.WantsCollection(collection) {
return nil
}

if len(sub.wantedDids) > 0 {
Expand Down Expand Up @@ -184,26 +191,40 @@ func (s *Server) GetSeq() int64 {
return s.seq
}

func (s *Server) AddSubscriber(ws *websocket.Conn, realIP string, wantedCollections []string, wantedDids []string, cursor *int64) *Subscriber {
func (s *Server) AddSubscriber(ws *websocket.Conn, realIP string, wantedCollectionPrefixes []string, wantedCollections []string, wantedDids []string, cursor *int64) *Subscriber {
s.lk.Lock()
defer s.lk.Unlock()

colMap := make(map[string]struct{})
for _, c := range wantedCollections {
colMap[c] = struct{}{}
}

didMap := make(map[string]struct{})
for _, d := range wantedDids {
didMap[d] = struct{}{}
}

// Build the WantedCollections struct
var wantedCol *WantedCollections
if len(wantedCollections) > 0 || len(wantedCollectionPrefixes) > 0 {
wantedCol = &WantedCollections{
Prefixes: wantedCollectionPrefixes,
FullPaths: make(map[string]struct{}),
}

// Sort the prefixes by length so we test the shortest prefixes first
slices.SortFunc(wantedCol.Prefixes, func(a, b string) int {
return len(a) - len(b)
})

// Add the full paths to the map
for _, c := range wantedCollections {
wantedCol.FullPaths[c] = struct{}{}
}
}

sub := Subscriber{
ws: ws,
realIP: realIP,
buf: make(chan *[]byte, 10_000),
id: s.nextSub,
wantedCollections: colMap,
wantedCollections: wantedCol,
wantedDids: didMap,
cursor: cursor,
deliveredCounter: eventsDelivered.WithLabelValues(realIP),
Expand All @@ -219,7 +240,7 @@ func (s *Server) AddSubscriber(ws *websocket.Conn, realIP string, wantedCollecti
slog.Info("adding subscriber",
"real_ip", realIP,
"id", sub.id,
"wantedCollections", wantedCollections,
"wantedCollections", wantedCol,
"wantedDids", wantedDids,
)

Expand All @@ -241,42 +262,63 @@ func (s *Server) HandleSubscribe(c echo.Context) error {
ctx, cancel := context.WithCancel(c.Request().Context())
defer cancel()

ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer ws.Close()

wantedCollections := []string{}
wantedCollectionPrefixes := []string{}
qWantedCollections := c.Request().URL.Query()["wantedCollections"]
if len(qWantedCollections) > 0 {
for _, c := range qWantedCollections {
col, err := syntax.ParseNSID(c)
for _, wantedCol := range qWantedCollections {
if strings.HasSuffix(wantedCol, ".*") {
wantedCollectionPrefixes = append(wantedCollectionPrefixes, strings.TrimSuffix(wantedCol, "*"))
// Make sure the prefix is valid
if _, err := syntax.ParseNSID(strings.TrimSuffix(wantedCol, ".*")); err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("invalid collection prefix: %s", wantedCol))
return fmt.Errorf("invalid collection prefix: %s", wantedCol)
}
continue
}

col, err := syntax.ParseNSID(wantedCol)
if err != nil {
return fmt.Errorf("invalid collection: %s", c)
c.String(http.StatusBadRequest, fmt.Sprintf("invalid collection: %s", wantedCol))
return fmt.Errorf("invalid collection: %s", wantedCol)
}
wantedCollections = append(wantedCollections, col.String())
}
}

// Reject requests with too many wanted collections
if len(wantedCollections)+len(wantedCollectionPrefixes) > 100 {
c.String(http.StatusBadRequest, "too many wanted collections")
return fmt.Errorf("too many wanted collections")
}

wantedDids := []string{}
qWantedDids := c.Request().URL.Query()["wantedDids"]
if len(qWantedDids) > 0 {
for _, d := range qWantedDids {
did, err := syntax.ParseDID(d)
if err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("invalid did: %s", d))
return fmt.Errorf("invalid did: %s", d)
}
wantedDids = append(wantedDids, did.String())
}
}

// Reject requests with too many wanted DIDs
if len(wantedDids) > 10_000 {
c.String(http.StatusBadRequest, "too many wanted DIDs")
return fmt.Errorf("too many wanted DIDs")
}

var cursor *int64
var err error
qCursor := c.Request().URL.Query().Get("cursor")
if qCursor != "" {
cursor = new(int64)
*cursor, err = strconv.ParseInt(qCursor, 10, 64)
if err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("invalid cursor: %s", qCursor))
return fmt.Errorf("invalid cursor: %s", qCursor)
}

Expand All @@ -286,6 +328,12 @@ func (s *Server) HandleSubscribe(c echo.Context) error {
}
}

ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer ws.Close()

log := slog.With("source", "server_handle_subscribe", "socket_addr", ws.RemoteAddr().String(), "real_ip", c.RealIP())

go func() {
Expand All @@ -299,7 +347,7 @@ func (s *Server) HandleSubscribe(c echo.Context) error {
}
}()

sub := s.AddSubscriber(ws, c.RealIP(), wantedCollections, wantedDids, cursor)
sub := s.AddSubscriber(ws, c.RealIP(), wantedCollectionPrefixes, wantedCollections, wantedDids, cursor)
defer s.RemoveSubscriber(sub.id)

if cursor != nil {
Expand Down Expand Up @@ -354,3 +402,26 @@ func (s *Server) HandleSubscribe(c echo.Context) error {
}
}
}

// WantsCollection reports whether the subscriber should receive an event for
// the given collection.
//
// It returns true when any of the following holds:
//   - the subscriber has no collection filter (wantedCollections is nil);
//   - collection is empty, i.e. the event is not tied to a collection. The
//     inline check this method replaced skipped filtering in that case;
//     presumably these are the Account and Identity events that every
//     subscriber is documented to receive (confirm against the callers);
//   - collection is one of the exact names in FullPaths;
//   - collection starts with one of the entries in Prefixes.
//
// Otherwise it returns false.
func (sub *Subscriber) WantsCollection(collection string) bool {
	// No filter configured, or nothing to filter on: always deliver.
	if sub.wantedCollections == nil || collection == "" {
		return true
	}

	// Exact names first: a single map lookup. Indexing a nil or empty map is
	// safe in Go, so no length guard is needed.
	if _, ok := sub.wantedCollections.FullPaths[collection]; ok {
		return true
	}

	// Fall back to a linear scan over the prefixes.
	for _, prefix := range sub.wantedCollections.Prefixes {
		if strings.HasPrefix(collection, prefix) {
			return true
		}
	}

	return false
}

0 comments on commit d118127

Please sign in to comment.