diff --git a/README.md b/README.md index c89bc88..c5548ba 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,11 @@ Connect to `ws://localhost:6008/subscribe` to start the stream The following Query Parameters are supported: - `wantedCollections` - An array of [Collection NSIDs](https://atproto.com/specs/nsid) to filter which records you receive on your stream (default empty = all collections) - - Regardless of desired collections, all subscribers recieve Account and Identity events + - `wantedCollections` supports NSID path prefixes i.e. `app.bsky.graph.*`, or `app.bsky.*`. The prefix before the `.*` must pass NSID validation and Jetstream **does not** support incomplete prefixes i.e. `app.bsky.graph.fo*`. + - Regardless of desired collections, all subscribers recieve Account and Identity events. + - You can specify at most 100 wanted collections/prefixes. - `wantedDids` - An array of Repo DIDs to filter which records you receive on your stream (Default empty = all repos) + - You can specify at most 10,000 wanted DIDs. - `cursor` - A unix microseconds timestamp cursor to begin playback from - An absent cursor or a cursor from the future will result in live-tail operation - When reconnecting, use the `time_us` from your most recently processed event and maybe provide a negative buffer (i.e. subtract a few seconds) to ensure gapless playback diff --git a/cmd/jetstream/server.go b/cmd/jetstream/server.go index 1e0c7ce..777533a 100644 --- a/cmd/jetstream/server.go +++ b/cmd/jetstream/server.go @@ -4,7 +4,10 @@ import ( "context" "fmt" "log/slog" + "net/http" + "slices" "strconv" + "strings" "sync" "time" @@ -23,17 +26,23 @@ var ( upgrader = websocket.Upgrader{} ) +type WantedCollections struct { + Prefixes []string + FullPaths map[string]struct{} +} + type Subscriber struct { - ws *websocket.Conn - realIP string - seq int64 - buf chan *[]byte - id int64 - cLk sync.Mutex - cursor *int64 - deliveredCounter prometheus.Counter - bytesCounter prometheus.Counter - wantedCollections map[string]struct{} + ws *websocket.Conn + realIP string + seq int64 + buf chan *[]byte + id int64 + cLk sync.Mutex + cursor *int64 + deliveredCounter prometheus.Counter + bytesCounter prometheus.Counter + // wantedCollections is nil if the subscriber wants all collections + wantedCollections *WantedCollections wantedDids map[string]struct{} rl *rate.Limiter } @@ -116,10 +125,8 @@ func (s *Server) Emit(ctx context.Context, e models.Event) error { } func emitToSubscriber(ctx context.Context, log *slog.Logger, sub *Subscriber, timeUS int64, did, collection string, playback bool, getEncodedEvent func() []byte) error { - if len(sub.wantedCollections) > 0 && collection != "" { - if _, ok := sub.wantedCollections[collection]; !ok { - return nil - } + if !sub.WantsCollection(collection) { + return nil } if len(sub.wantedDids) > 0 { @@ -184,26 +191,40 @@ func (s *Server) GetSeq() int64 { return s.seq } -func (s *Server) AddSubscriber(ws *websocket.Conn, realIP string, wantedCollections []string, wantedDids []string, cursor *int64) *Subscriber { +func (s *Server) AddSubscriber(ws *websocket.Conn, realIP string, wantedCollectionPrefixes []string, wantedCollections []string, wantedDids []string, cursor *int64) *Subscriber { s.lk.Lock() defer s.lk.Unlock() - colMap := make(map[string]struct{}) - for _, c := range wantedCollections { - colMap[c] = struct{}{} - } - didMap := make(map[string]struct{}) for _, d := range wantedDids { didMap[d] = struct{}{} } + // Build the WantedCollections struct + var wantedCol *WantedCollections + if len(wantedCollections) > 0 || len(wantedCollectionPrefixes) > 0 { + wantedCol = &WantedCollections{ + Prefixes: wantedCollectionPrefixes, + FullPaths: make(map[string]struct{}), + } + + // Sort the prefixes by length so we test the shortest prefixes first + slices.SortFunc(wantedCol.Prefixes, func(a, b string) int { + return len(a) - len(b) + }) + + // Add the full paths to the map + for _, c := range wantedCollections { + wantedCol.FullPaths[c] = struct{}{} + } + } + sub := Subscriber{ ws: ws, realIP: realIP, buf: make(chan *[]byte, 10_000), id: s.nextSub, - wantedCollections: colMap, + wantedCollections: wantedCol, wantedDids: didMap, cursor: cursor, deliveredCounter: eventsDelivered.WithLabelValues(realIP), @@ -219,7 +240,7 @@ func (s *Server) AddSubscriber(ws *websocket.Conn, realIP string, wantedCollecti slog.Info("adding subscriber", "real_ip", realIP, "id", sub.id, - "wantedCollections", wantedCollections, + "wantedCollections", wantedCol, "wantedDids", wantedDids, ) @@ -241,42 +262,63 @@ func (s *Server) HandleSubscribe(c echo.Context) error { ctx, cancel := context.WithCancel(c.Request().Context()) defer cancel() - ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) - if err != nil { - return err - } - defer ws.Close() - wantedCollections := []string{} + wantedCollectionPrefixes := []string{} qWantedCollections := c.Request().URL.Query()["wantedCollections"] if len(qWantedCollections) > 0 { - for _, c := range qWantedCollections { - col, err := syntax.ParseNSID(c) + for _, wantedCol := range qWantedCollections { + if strings.HasSuffix(wantedCol, ".*") { + wantedCollectionPrefixes = append(wantedCollectionPrefixes, strings.TrimSuffix(wantedCol, "*")) + // Make sure the prefix is valid + if _, err := syntax.ParseNSID(strings.TrimSuffix(wantedCol, ".*")); err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("invalid collection prefix: %s", wantedCol)) + return fmt.Errorf("invalid collection prefix: %s", wantedCol) + } + continue + } + + col, err := syntax.ParseNSID(wantedCol) if err != nil { - return fmt.Errorf("invalid collection: %s", c) + c.String(http.StatusBadRequest, fmt.Sprintf("invalid collection: %s", wantedCol)) + return fmt.Errorf("invalid collection: %s", wantedCol) } wantedCollections = append(wantedCollections, col.String()) } } + // Reject requests with too many wanted collections + if len(wantedCollections)+len(wantedCollectionPrefixes) > 100 { + c.String(http.StatusBadRequest, "too many wanted collections") + return fmt.Errorf("too many wanted collections") + } + wantedDids := []string{} qWantedDids := c.Request().URL.Query()["wantedDids"] if len(qWantedDids) > 0 { for _, d := range qWantedDids { did, err := syntax.ParseDID(d) if err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("invalid did: %s", d)) return fmt.Errorf("invalid did: %s", d) } wantedDids = append(wantedDids, did.String()) } } + // Reject requests with too many wanted DIDs + if len(wantedDids) > 10_000 { + c.String(http.StatusBadRequest, "too many wanted DIDs") + return fmt.Errorf("too many wanted DIDs") + } + var cursor *int64 + var err error qCursor := c.Request().URL.Query().Get("cursor") if qCursor != "" { cursor = new(int64) *cursor, err = strconv.ParseInt(qCursor, 10, 64) if err != nil { + c.String(http.StatusBadRequest, fmt.Sprintf("invalid cursor: %s", qCursor)) return fmt.Errorf("invalid cursor: %s", qCursor) } @@ -286,6 +328,12 @@ func (s *Server) HandleSubscribe(c echo.Context) error { } } + ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return err + } + defer ws.Close() + log := slog.With("source", "server_handle_subscribe", "socket_addr", ws.RemoteAddr().String(), "real_ip", c.RealIP()) go func() { @@ -299,7 +347,7 @@ func (s *Server) HandleSubscribe(c echo.Context) error { } }() - sub := s.AddSubscriber(ws, c.RealIP(), wantedCollections, wantedDids, cursor) + sub := s.AddSubscriber(ws, c.RealIP(), wantedCollectionPrefixes, wantedCollections, wantedDids, cursor) defer s.RemoveSubscriber(sub.id) if cursor != nil { @@ -354,3 +402,26 @@ func (s *Server) HandleSubscribe(c echo.Context) error { } } } + +// WantsCollection returns true if the subscriber wants the given collection +func (sub *Subscriber) WantsCollection(collection string) bool { + if sub.wantedCollections == nil { + return true + } + + // Start with the full paths for fast lookup + if len(sub.wantedCollections.FullPaths) > 0 { + if _, match := sub.wantedCollections.FullPaths[collection]; match { + return true + } + } + + // Check the prefixes (shortest first) + for _, prefix := range sub.wantedCollections.Prefixes { + if strings.HasPrefix(collection, prefix) { + return true + } + } + + return false +}