Skip to content

Commit

Permalink
fix(privval): Make maxReadSize configurable (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
agouin authored Jul 10, 2024
1 parent 886d380 commit ddc0ad6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
7 changes: 5 additions & 2 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const (
flagGRPCAddress = "grpc"
flagOperator = "operator"
flagSentry = "sentry"
flagMaxReadSize = "max-read-size"
)

func startCmd() *cobra.Command {
Expand Down Expand Up @@ -70,13 +71,14 @@ func startCmd() *cobra.Command {
// if we're running in kubernetes, we can auto-discover sentries
operator, _ := cmd.Flags().GetBool(flagOperator)
sentries, _ := cmd.Flags().GetStringArray(flagSentry)
maxReadSize, _ := cmd.Flags().GetInt(flagMaxReadSize)

watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries)
watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries, maxReadSize)
if err != nil {
return err
}
defer logIfErr(logger, watcher.Stop)
go watcher.Watch(ctx)
go watcher.Watch(ctx, maxReadSize)

waitForSignals(logger)

Expand All @@ -90,6 +92,7 @@ func startCmd() *cobra.Command {
cmd.Flags().StringP(flagGRPCAddress, "g", "", "GRPC address for the proxy")
cmd.Flags().BoolP(flagAll, "a", false, "Connect to sentries on all nodes")
cmd.Flags().String(flagLogLevel, "info", "Set log level (debug, info, error, none)")
cmd.Flags().Int(flagMaxReadSize, 1024*1024, "Max read size for privval messages")

return cmd
}
Expand Down
10 changes: 6 additions & 4 deletions cmd/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewSentryWatcher(
hc signer.HorcruxConnection,
operator bool,
sentries []string,
maxReadSize int,
) (*SentryWatcher, error) {
var clientset *kubernetes.Clientset
var thisNode string
Expand Down Expand Up @@ -79,7 +80,7 @@ func NewSentryWatcher(
persistentSentries := make([]*signer.ReconnRemoteSigner, len(sentries))
for i, sentry := range sentries {
dialer := net.Dialer{Timeout: 2 * time.Second}
persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer)
persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer, maxReadSize)
}

return &SentryWatcher{
Expand All @@ -98,7 +99,7 @@ func NewSentryWatcher(

// Watch will reconcile the sentries with the kube api at a reasonable interval.
// It must be called only once.
func (w *SentryWatcher) Watch(ctx context.Context) {
func (w *SentryWatcher) Watch(ctx context.Context, maxReadSize int) {
for _, sentry := range w.persistentSentries {
if err := sentry.Start(); err != nil {
w.log.Error("Failed to start persistent sentry", "error", err)
Expand All @@ -113,7 +114,7 @@ func (w *SentryWatcher) Watch(ctx context.Context) {
defer timer.Stop()

for {
if err := w.reconcileSentries(ctx); err != nil {
if err := w.reconcileSentries(ctx, maxReadSize); err != nil {
w.log.Error("Failed to reconcile sentries with kube api", "error", err)
}
select {
Expand Down Expand Up @@ -144,6 +145,7 @@ func (w *SentryWatcher) Stop() error {

func (w *SentryWatcher) reconcileSentries(
ctx context.Context,
maxReadSize int,
) error {
configNodes := make([]string, 0)

Expand Down Expand Up @@ -220,7 +222,7 @@ func (w *SentryWatcher) reconcileSentries(

for _, newSentry := range newSentries {
dialer := net.Dialer{Timeout: 2 * time.Second}
s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer)
s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer, maxReadSize)

if err := s.Start(); err != nil {
return fmt.Errorf("failed to start new remote signer(s): %w", err)
Expand Down
14 changes: 10 additions & 4 deletions signer/remote_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type ReconnRemoteSigner struct {
horcruxConnection HorcruxConnection

dialer net.Dialer

maxReadSize int
}

// NewReconnRemoteSigner return a ReconnRemoteSigner that will dial using the given
Expand All @@ -43,12 +45,14 @@ func NewReconnRemoteSigner(
logger cometlog.Logger,
horcruxConnection HorcruxConnection,
dialer net.Dialer,
maxReadSize int,
) *ReconnRemoteSigner {
rs := &ReconnRemoteSigner{
address: address,
dialer: dialer,
horcruxConnection: horcruxConnection,
privKey: cometcryptoed25519.GenPrivKey(),
maxReadSize: maxReadSize,
}

rs.BaseService = *cometservice.NewBaseService(logger, "RemoteSigner", rs)
Expand Down Expand Up @@ -113,7 +117,7 @@ func (rs *ReconnRemoteSigner) loop() {
return
}

req, err := ReadMsg(conn)
req, err := ReadMsg(conn, rs.maxReadSize)
if err != nil {
rs.Logger.Error("readMsg", "err", err)
conn.Close()
Expand Down Expand Up @@ -147,9 +151,11 @@ func (rs *ReconnRemoteSigner) loop() {
}

// ReadMsg reads a message from an io.Reader
func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) {
const maxRemoteSignerMsgSize = 1024 * 10
protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize)
func ReadMsg(reader io.Reader, maxReadSize int) (msg cometprotoprivval.Message, err error) {
if maxReadSize <= 0 {
maxReadSize = 1024 * 1024 // 1MB
}
protoReader := protoio.NewDelimitedReader(reader, maxReadSize)
_, err = protoReader.ReadMsg(&msg)
return msg, err
}
Expand Down

0 comments on commit ddc0ad6

Please sign in to comment.