From ddc0ad6052e9b5ce77989055f2865cf729463893 Mon Sep 17 00:00:00 2001 From: Andrew Gouin Date: Wed, 10 Jul 2024 16:02:18 -0600 Subject: [PATCH] fix(privval): Make maxReadSize configurable (#13) --- cmd/start.go | 7 +++++-- cmd/watcher.go | 10 ++++++---- signer/remote_signer.go | 14 ++++++++++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/cmd/start.go b/cmd/start.go index b23558d..44b16f5 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -18,6 +18,7 @@ const ( flagGRPCAddress = "grpc" flagOperator = "operator" flagSentry = "sentry" + flagMaxReadSize = "max-read-size" ) func startCmd() *cobra.Command { @@ -70,13 +71,14 @@ func startCmd() *cobra.Command { // if we're running in kubernetes, we can auto-discover sentries operator, _ := cmd.Flags().GetBool(flagOperator) sentries, _ := cmd.Flags().GetStringArray(flagSentry) + maxReadSize, _ := cmd.Flags().GetInt(flagMaxReadSize) - watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries) + watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries, maxReadSize) if err != nil { return err } defer logIfErr(logger, watcher.Stop) - go watcher.Watch(ctx) + go watcher.Watch(ctx, maxReadSize) waitForSignals(logger) @@ -90,6 +92,7 @@ func startCmd() *cobra.Command { cmd.Flags().StringP(flagGRPCAddress, "g", "", "GRPC address for the proxy") cmd.Flags().BoolP(flagAll, "a", false, "Connect to sentries on all nodes") cmd.Flags().String(flagLogLevel, "info", "Set log level (debug, info, error, none)") + cmd.Flags().Int(flagMaxReadSize, 1024*1024, "Max read size for privval messages") return cmd } diff --git a/cmd/watcher.go b/cmd/watcher.go index 309c04b..d6f588a 100644 --- a/cmd/watcher.go +++ b/cmd/watcher.go @@ -43,6 +43,7 @@ func NewSentryWatcher( hc signer.HorcruxConnection, operator bool, sentries []string, + maxReadSize int, ) (*SentryWatcher, error) { var clientset *kubernetes.Clientset var thisNode string @@ -79,7 +80,7 @@ func NewSentryWatcher( persistentSentries := make([]*signer.ReconnRemoteSigner, len(sentries)) for i, sentry := range sentries { dialer := net.Dialer{Timeout: 2 * time.Second} - persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer) + persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer, maxReadSize) } return &SentryWatcher{ @@ -98,7 +99,7 @@ func NewSentryWatcher( // Watch will reconcile the sentries with the kube api at a reasonable interval. // It must be called only once. -func (w *SentryWatcher) Watch(ctx context.Context) { +func (w *SentryWatcher) Watch(ctx context.Context, maxReadSize int) { for _, sentry := range w.persistentSentries { if err := sentry.Start(); err != nil { w.log.Error("Failed to start persistent sentry", "error", err) @@ -113,7 +114,7 @@ func (w *SentryWatcher) Watch(ctx context.Context) { defer timer.Stop() for { - if err := w.reconcileSentries(ctx); err != nil { + if err := w.reconcileSentries(ctx, maxReadSize); err != nil { w.log.Error("Failed to reconcile sentries with kube api", "error", err) } select { @@ -144,6 +145,7 @@ func (w *SentryWatcher) Stop() error { func (w *SentryWatcher) reconcileSentries( ctx context.Context, + maxReadSize int, ) error { configNodes := make([]string, 0) @@ -220,7 +222,7 @@ func (w *SentryWatcher) reconcileSentries( for _, newSentry := range newSentries { dialer := net.Dialer{Timeout: 2 * time.Second} - s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer) + s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer, maxReadSize) if err := s.Start(); err != nil { return fmt.Errorf("failed to start new remote signer(s): %w", err) diff --git a/signer/remote_signer.go b/signer/remote_signer.go index 034e69d..6e1a45d 100644 --- a/signer/remote_signer.go +++ b/signer/remote_signer.go @@ -31,6 +31,8 @@ type ReconnRemoteSigner struct { horcruxConnection HorcruxConnection dialer net.Dialer + + maxReadSize int } // NewReconnRemoteSigner return a ReconnRemoteSigner that will dial using the given @@ -43,12 +45,14 @@ func NewReconnRemoteSigner( logger cometlog.Logger, horcruxConnection HorcruxConnection, dialer net.Dialer, + maxReadSize int, ) *ReconnRemoteSigner { rs := &ReconnRemoteSigner{ address: address, dialer: dialer, horcruxConnection: horcruxConnection, privKey: cometcryptoed25519.GenPrivKey(), + maxReadSize: maxReadSize, } rs.BaseService = *cometservice.NewBaseService(logger, "RemoteSigner", rs) @@ -113,7 +117,7 @@ func (rs *ReconnRemoteSigner) loop() { return } - req, err := ReadMsg(conn) + req, err := ReadMsg(conn, rs.maxReadSize) if err != nil { rs.Logger.Error("readMsg", "err", err) conn.Close() @@ -147,9 +151,11 @@ func (rs *ReconnRemoteSigner) loop() { } // ReadMsg reads a message from an io.Reader -func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) { - const maxRemoteSignerMsgSize = 1024 * 10 - protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize) +func ReadMsg(reader io.Reader, maxReadSize int) (msg cometprotoprivval.Message, err error) { + if maxReadSize <= 0 { + maxReadSize = 1024 * 1024 // 1MB + } + protoReader := protoio.NewDelimitedReader(reader, maxReadSize) _, err = protoReader.ReadMsg(&msg) return msg, err }