diff --git a/cmd/splitter/main.go b/cmd/splitter/main.go deleted file mode 100644 index a6d92e17e..000000000 --- a/cmd/splitter/main.go +++ /dev/null @@ -1,147 +0,0 @@ -package main - -import ( - "context" - "os" - "os/signal" - "syscall" - "time" - - "github.com/bluesky-social/indigo/bgs" - "github.com/bluesky-social/indigo/util/version" - _ "go.uber.org/automaxprocs" - - _ "net/http/pprof" - - _ "github.com/joho/godotenv/autoload" - - logging "github.com/ipfs/go-log" - "github.com/urfave/cli/v2" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" - "go.opentelemetry.io/otel/sdk/resource" - tracesdk "go.opentelemetry.io/otel/sdk/trace" - semconv "go.opentelemetry.io/otel/semconv/v1.4.0" -) - -var log = logging.Logger("splitter") - -func init() { - // control log level using, eg, GOLOG_LOG_LEVEL=debug - //logging.SetAllLoggers(logging.LevelDebug) -} - -func main() { - run(os.Args) -} - -func run(args []string) { - app := cli.App{ - Name: "splitter", - Usage: "firehose proxy", - Version: version.Version, - } - - app.Flags = []cli.Flag{ - &cli.BoolFlag{ - Name: "crawl-insecure-ws", - Usage: "when connecting to PDS instances, use ws:// instead of wss://", - }, - &cli.StringFlag{ - Name: "api-listen", - Value: ":2480", - }, - &cli.StringFlag{ - Name: "metrics-listen", - Value: ":2481", - EnvVars: []string{"SPLITTER_METRICS_LISTEN"}, - }, - } - - app.Action = Splitter - err := app.Run(os.Args) - if err != nil { - log.Fatal(err) - } -} - -func Splitter(cctx *cli.Context) error { - // Trap SIGINT to trigger a shutdown. - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) - - // Enable OTLP HTTP exporter - // For relevant environment variables: - // https://pkg.go.dev/go.opentelemetry.io/otel/exporters/otlp/otlptrace#readme-environment-variables - // At a minimum, you need to set - // OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 - if ep := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT"); ep != "" { - log.Infow("setting up trace exporter", "endpoint", ep) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - exp, err := otlptracehttp.New(ctx) - if err != nil { - log.Fatalw("failed to create trace exporter", "error", err) - } - defer func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := exp.Shutdown(ctx); err != nil { - log.Errorw("failed to shutdown trace exporter", "error", err) - } - }() - - tp := tracesdk.NewTracerProvider( - tracesdk.WithBatcher(exp), - tracesdk.WithResource(resource.NewWithAttributes( - semconv.SchemaURL, - semconv.ServiceNameKey.String("splitter"), - attribute.String("env", os.Getenv("ENVIRONMENT")), // DataDog - attribute.String("environment", os.Getenv("ENVIRONMENT")), // Others - attribute.Int64("ID", 1), - )), - ) - otel.SetTracerProvider(tp) - } - - spl := splitter.New(cctx.String("bgs-host")) - - // set up metrics endpoint - go func() { - if err := spl.StartMetrics(cctx.String("metrics-listen")); err != nil { - log.Fatalf("failed to start metrics endpoint: %s", err) - } - }() - - runErr := make(chan error, 1) - - go func() { - err := spl.Start(cctx.String("api-listen")) - runErr <- err - }() - - log.Infow("startup complete") - select { - case <-signals: - log.Info("received shutdown signal") - errs := spl.Shutdown() - for err := range errs { - log.Errorw("error during Splitter shutdown", "err", err) - } - case err := <-runErr: - if err != nil { - log.Errorw("error during Splitter startup", "err", err) - } - log.Info("shutting down") - errs := bgs.Shutdown() - for err := range errs { - log.Errorw("error during Splitter shutdown", "err", err) - } - } - - log.Info("shutdown complete") - - return nil -} diff --git a/splitter/splitter.go b/splitter/splitter.go index efe573a09..059a82fcd 100644 --- a/splitter/splitter.go +++ b/splitter/splitter.go @@ -4,14 +4,25 @@ import ( "context" "fmt" "math/rand" + "net" + "net/http" + "strconv" + "strings" "sync" "time" + "github.com/bluesky-social/indigo/bgs" events "github.com/bluesky-social/indigo/events" "github.com/bluesky-social/indigo/events/schedulers/sequential" + lexutil "github.com/bluesky-social/indigo/lex/util" "github.com/bluesky-social/indigo/models" "github.com/gorilla/websocket" logging "github.com/ipfs/go-log" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + promclient "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + dto "github.com/prometheus/client_model/go" ) var log = logging.Logger("splitter") @@ -20,23 +31,315 @@ type Splitter struct { Host string erb *EventRingBuffer events *events.EventManager + + // Management of Socket Consumers + consumersLk sync.RWMutex + nextConsumerID uint64 + consumers map[uint64]*SocketConsumer } -func NewSplitter(host string, persister events.EventPersistence) *Splitter { +func NewSplitter(host string) *Splitter { erb := NewEventRingBuffer(20000, 1000) em := events.NewEventManager(erb) return &Splitter{ - Host: host, - erb: erb, - events: em, + Host: host, + erb: erb, + events: em, + consumers: make(map[uint64]*SocketConsumer), + } +} + +func (s *Splitter) Start(addr string) error { + var lc net.ListenConfig + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + go s.subscribeWithRedialer(context.Background(), s.Host, 0) + + li, err := lc.Listen(ctx, "tcp", addr) + if err != nil { + return err } + return s.StartWithListener(li) +} + +func (s *Splitter) StartMetrics(listen string) error { + http.Handle("/metrics", promhttp.Handler()) + return http.ListenAndServe(listen, nil) } -func (s *Splitter) Start() error { +func (s *Splitter) Shutdown() error { return nil } +func (s *Splitter) StartWithListener(listen net.Listener) error { + e := echo.New() + e.HideBanner = true + + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + AllowOrigins: []string{"http://localhost:*", "https://bgs.bsky-sandbox.dev"}, + AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization}, + })) + + /* + if !s.ssl { + e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ + Format: "method=${method}, uri=${uri}, status=${status} latency=${latency_human}\n", + })) + } else { + e.Use(middleware.LoggerWithConfig(middleware.DefaultLoggerConfig)) + } + */ + + e.Use(bgs.MetricsMiddleware) + + e.HTTPErrorHandler = func(err error, ctx echo.Context) { + switch err := err.(type) { + case *echo.HTTPError: + if err2 := ctx.JSON(err.Code, map[string]any{ + "error": err.Message, + }); err2 != nil { + log.Errorf("Failed to write http error: %s", err2) + } + default: + sendHeader := true + if ctx.Path() == "/xrpc/com.atproto.sync.subscribeRepos" { + sendHeader = false + } + + log.Warnf("HANDLER ERROR: (%s) %s", ctx.Path(), err) + + if strings.HasPrefix(ctx.Path(), "/admin/") { + ctx.JSON(500, map[string]any{ + "error": err.Error(), + }) + return + } + + if sendHeader { + ctx.Response().WriteHeader(500) + } + } + } + + // TODO: this API is temporary until we formalize what we want here + + e.GET("/xrpc/com.atproto.sync.subscribeRepos", s.EventsHandler) + e.GET("/xrpc/_health", s.HandleHealthCheck) + + // In order to support booting on random ports in tests, we need to tell the + // Echo instance it's already got a port, and then use its StartServer + // method to re-use that listener. + e.Listener = listen + srv := &http.Server{} + return e.StartServer(srv) +} + +type HealthStatus struct { + Status string `json:"status"` + Message string `json:"msg,omitempty"` +} + +func (s *Splitter) HandleHealthCheck(c echo.Context) error { + return c.JSON(200, HealthStatus{Status: "ok"}) +} + +func (s *Splitter) EventsHandler(c echo.Context) error { + var since *int64 + if sinceVal := c.QueryParam("cursor"); sinceVal != "" { + sval, err := strconv.ParseInt(sinceVal, 10, 64) + if err != nil { + return err + } + since = &sval + } + + ctx, cancel := context.WithCancel(c.Request().Context()) + defer cancel() + + // TODO: authhhh + conn, err := websocket.Upgrade(c.Response(), c.Request(), c.Response().Header(), 10<<10, 10<<10) + if err != nil { + return fmt.Errorf("upgrading websocket: %w", err) + } + + lastWriteLk := sync.Mutex{} + lastWrite := time.Now() + + // Start a goroutine to ping the client every 30 seconds to check if it's + // still alive. If the client doesn't respond to a ping within 5 seconds, + // we'll close the connection and teardown the consumer. + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + lastWriteLk.Lock() + lw := lastWrite + lastWriteLk.Unlock() + + if time.Since(lw) < 30*time.Second { + continue + } + + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { + log.Errorf("failed to ping client: %s", err) + cancel() + return + } + case <-ctx.Done(): + return + } + } + }() + + conn.SetPingHandler(func(message string) error { + err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60)) + if err == websocket.ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + }) + + // Start a goroutine to read messages from the client and discard them. + go func() { + for { + _, _, err := conn.ReadMessage() + if err != nil { + log.Errorf("failed to read message from client: %s", err) + cancel() + return + } + } + }() + + ident := c.RealIP() + "-" + c.Request().UserAgent() + + evts, cleanup, err := s.events.Subscribe(ctx, ident, func(evt *events.XRPCStreamEvent) bool { return true }, since) + if err != nil { + return err + } + defer cleanup() + + // Keep track of the consumer for metrics and admin endpoints + consumer := SocketConsumer{ + RemoteAddr: c.RealIP(), + UserAgent: c.Request().UserAgent(), + ConnectedAt: time.Now(), + } + sentCounter := eventsSentCounter.WithLabelValues(consumer.RemoteAddr, consumer.UserAgent) + consumer.EventsSent = sentCounter + + consumerID := s.registerConsumer(&consumer) + defer s.cleanupConsumer(consumerID) + + log.Infow("new consumer", + "remote_addr", consumer.RemoteAddr, + "user_agent", consumer.UserAgent, + "cursor", since, + "consumer_id", consumerID, + ) + + header := events.EventHeader{Op: events.EvtKindMessage} + for { + select { + case evt := <-evts: + wc, err := conn.NextWriter(websocket.BinaryMessage) + if err != nil { + log.Errorf("failed to get next writer: %s", err) + return err + } + + var obj lexutil.CBOR + + switch { + case evt.Error != nil: + header.Op = events.EvtKindErrorFrame + obj = evt.Error + case evt.RepoCommit != nil: + header.MsgType = "#commit" + obj = evt.RepoCommit + case evt.RepoHandle != nil: + header.MsgType = "#handle" + obj = evt.RepoHandle + case evt.RepoInfo != nil: + header.MsgType = "#info" + obj = evt.RepoInfo + case evt.RepoMigrate != nil: + header.MsgType = "#migrate" + obj = evt.RepoMigrate + case evt.RepoTombstone != nil: + header.MsgType = "#tombstone" + obj = evt.RepoTombstone + default: + return fmt.Errorf("unrecognized event kind") + } + + if err := header.MarshalCBOR(wc); err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + if err := obj.MarshalCBOR(wc); err != nil { + return fmt.Errorf("failed to write event: %w", err) + } + + if err := wc.Close(); err != nil { + return fmt.Errorf("failed to flush-close our event write: %w", err) + } + lastWriteLk.Lock() + lastWrite = time.Now() + lastWriteLk.Unlock() + sentCounter.Inc() + case <-ctx.Done(): + return nil + } + } +} + +type SocketConsumer struct { + UserAgent string + RemoteAddr string + ConnectedAt time.Time + EventsSent promclient.Counter +} + +func (s *Splitter) registerConsumer(c *SocketConsumer) uint64 { + s.consumersLk.Lock() + defer s.consumersLk.Unlock() + + id := s.nextConsumerID + s.nextConsumerID++ + + s.consumers[id] = c + + return id +} + +func (s *Splitter) cleanupConsumer(id uint64) { + s.consumersLk.Lock() + defer s.consumersLk.Unlock() + + c := s.consumers[id] + + var m = &dto.Metric{} + if err := c.EventsSent.Write(m); err != nil { + log.Errorf("failed to get sent counter: %s", err) + } + + log.Infow("consumer disconnected", + "consumer_id", id, + "remote_addr", c.RemoteAddr, + "user_agent", c.UserAgent, + "events_sent", m.Counter.GetValue()) + + delete(s.consumers, id) +} + func sleepForBackoff(b int) time.Duration { if b == 0 { return 0 @@ -139,6 +442,7 @@ func (rc *ringChunk) events() []*events.XRPCStreamEvent { } func (er *EventRingBuffer) Persist(ctx context.Context, evt *events.XRPCStreamEvent) error { + fmt.Println("persist event", sequenceForEvent(evt)) er.lk.Lock() defer er.lk.Unlock()