diff --git a/cmd/start.go b/cmd/start.go index 61a916f..5800fe0 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -18,6 +18,7 @@ const ( flagGRPCAddress = "grpc" flagOperator = "operator" flagSentry = "sentry" + flagMaxReadSize = "max-read-size" ) func startCmd() *cobra.Command { @@ -82,13 +83,14 @@ func startCmd() *cobra.Command { // if we're running in kubernetes, we can auto-discover sentries operator, _ := cmd.Flags().GetBool(flagOperator) sentries, _ := cmd.Flags().GetStringArray(flagSentry) + maxReadSize, _ := cmd.Flags().GetInt(flagMaxReadSize) - watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries) + watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries, maxReadSize) if err != nil { return err } defer logIfErr(logger, watcher.Stop) - go watcher.Watch(ctx) + go watcher.Watch(ctx, maxReadSize) waitForSignals(logger) @@ -102,6 +104,7 @@ func startCmd() *cobra.Command { cmd.Flags().StringP(flagGRPCAddress, "g", "", "GRPC address for the proxy") cmd.Flags().BoolP(flagAll, "a", false, "Connect to sentries on all nodes") cmd.Flags().String(flagLogLevel, "info", "Set log level (debug, info, error, none)") + cmd.Flags().Int(flagMaxReadSize, 1024*1024, "Max read size for privval messages") return cmd } diff --git a/cmd/watcher.go b/cmd/watcher.go index 1ec70a3..9df8055 100644 --- a/cmd/watcher.go +++ b/cmd/watcher.go @@ -43,6 +43,7 @@ func NewSentryWatcher( hc signer.HorcruxConnection, operator bool, sentries []string, + maxReadSize int, ) (*SentryWatcher, error) { var clientset *kubernetes.Clientset var thisNode string @@ -79,7 +80,7 @@ func NewSentryWatcher( persistentSentries := make([]*signer.ReconnRemoteSigner, len(sentries)) for i, sentry := range sentries { dialer := net.Dialer{Timeout: 2 * time.Second} - persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer) + persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer, maxReadSize) } return &SentryWatcher{ @@ -98,7 +99,7 @@ func NewSentryWatcher( // Watch will reconcile the sentries with the kube api at a reasonable interval. // It must be called only once. -func (w *SentryWatcher) Watch(ctx context.Context) { +func (w *SentryWatcher) Watch(ctx context.Context, maxReadSize int) { for _, sentry := range w.persistentSentries { if err := sentry.Start(); err != nil { w.log.Error("Failed to start persistent sentry", "error", err) @@ -113,7 +114,7 @@ func (w *SentryWatcher) Watch(ctx context.Context) { defer timer.Stop() for { - if err := w.reconcileSentries(ctx); err != nil { + if err := w.reconcileSentries(ctx, maxReadSize); err != nil { w.log.Error("Failed to reconcile sentries with kube api", "error", err) } select { @@ -144,6 +145,7 @@ func (w *SentryWatcher) Stop() error { func (w *SentryWatcher) reconcileSentries( ctx context.Context, + maxReadSize int, ) error { configNodes := make([]string, 0) @@ -220,7 +222,7 @@ func (w *SentryWatcher) reconcileSentries( for _, newSentry := range newSentries { dialer := net.Dialer{Timeout: 2 * time.Second} - s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer) + s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer, maxReadSize) if err := s.Start(); err != nil { return fmt.Errorf("failed to start new remote signer(s): %w", err) diff --git a/signer/remote_signer.go b/signer/remote_signer.go index 1c1c3de..ac18582 100644 --- a/signer/remote_signer.go +++ b/signer/remote_signer.go @@ -33,6 +33,8 @@ type ReconnRemoteSigner struct { horcruxConnection HorcruxConnection dialer net.Dialer + + maxReadSize int } // NewReconnRemoteSigner return a ReconnRemoteSigner that will dial using the given @@ -45,6 +47,7 @@ func NewReconnRemoteSigner( logger *slog.Logger, horcruxConnection HorcruxConnection, dialer net.Dialer, + maxReadSize int, ) *ReconnRemoteSigner { rs := &ReconnRemoteSigner{ logger: logger, @@ -52,6 +55,7 @@ func NewReconnRemoteSigner( dialer: dialer, horcruxConnection: horcruxConnection, privKey: cometcryptoed25519.GenPrivKey(), + maxReadSize: maxReadSize, } rs.BaseService = *cometservice.NewBaseService(nil, "RemoteSigner", rs) @@ -116,7 +120,7 @@ func (rs *ReconnRemoteSigner) loop() { return } - req, err := ReadMsg(conn) + req, err := ReadMsg(conn, rs.maxReadSize) if err != nil { rs.logger.Error("readMsg", "err", err) conn.Close() @@ -150,9 +154,11 @@ func (rs *ReconnRemoteSigner) loop() { } // ReadMsg reads a message from an io.Reader -func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) { - const maxRemoteSignerMsgSize = 1024 * 10 - protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize) +func ReadMsg(reader io.Reader, maxReadSize int) (msg cometprotoprivval.Message, err error) { + if maxReadSize <= 0 { + maxReadSize = 1024 * 1024 // 1MB + } + protoReader := protoio.NewDelimitedReader(reader, maxReadSize) _, err = protoReader.ReadMsg(&msg) return msg, err }