Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(privval): Make maxReadSize configurable #13

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const (
flagGRPCAddress = "grpc"
flagOperator = "operator"
flagSentry = "sentry"
flagMaxReadSize = "max-read-size"
)

func startCmd() *cobra.Command {
Expand Down Expand Up @@ -70,13 +71,14 @@ func startCmd() *cobra.Command {
// if we're running in kubernetes, we can auto-discover sentries
operator, _ := cmd.Flags().GetBool(flagOperator)
sentries, _ := cmd.Flags().GetStringArray(flagSentry)
maxReadSize, _ := cmd.Flags().GetInt(flagMaxReadSize)

watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries)
watcher, err := NewSentryWatcher(ctx, logger, all, hc, operator, sentries, maxReadSize)
if err != nil {
return err
}
defer logIfErr(logger, watcher.Stop)
go watcher.Watch(ctx)
go watcher.Watch(ctx, maxReadSize)

waitForSignals(logger)

Expand All @@ -90,6 +92,7 @@ func startCmd() *cobra.Command {
cmd.Flags().StringP(flagGRPCAddress, "g", "", "GRPC address for the proxy")
cmd.Flags().BoolP(flagAll, "a", false, "Connect to sentries on all nodes")
cmd.Flags().String(flagLogLevel, "info", "Set log level (debug, info, error, none)")
cmd.Flags().Int(flagMaxReadSize, 1024*1024, "Max read size for privval messages")

return cmd
}
Expand Down
10 changes: 6 additions & 4 deletions cmd/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewSentryWatcher(
hc signer.HorcruxConnection,
operator bool,
sentries []string,
maxReadSize int,
) (*SentryWatcher, error) {
var clientset *kubernetes.Clientset
var thisNode string
Expand Down Expand Up @@ -79,7 +80,7 @@ func NewSentryWatcher(
persistentSentries := make([]*signer.ReconnRemoteSigner, len(sentries))
for i, sentry := range sentries {
dialer := net.Dialer{Timeout: 2 * time.Second}
persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer)
persistentSentries[i] = signer.NewReconnRemoteSigner(sentry, logger, hc, dialer, maxReadSize)
}

return &SentryWatcher{
Expand All @@ -98,7 +99,7 @@ func NewSentryWatcher(

// Watch will reconcile the sentries with the kube api at a reasonable interval.
// It must be called only once.
func (w *SentryWatcher) Watch(ctx context.Context) {
func (w *SentryWatcher) Watch(ctx context.Context, maxReadSize int) {
for _, sentry := range w.persistentSentries {
if err := sentry.Start(); err != nil {
w.log.Error("Failed to start persistent sentry", "error", err)
Expand All @@ -113,7 +114,7 @@ func (w *SentryWatcher) Watch(ctx context.Context) {
defer timer.Stop()

for {
if err := w.reconcileSentries(ctx); err != nil {
if err := w.reconcileSentries(ctx, maxReadSize); err != nil {
w.log.Error("Failed to reconcile sentries with kube api", "error", err)
}
select {
Expand Down Expand Up @@ -144,6 +145,7 @@ func (w *SentryWatcher) Stop() error {

func (w *SentryWatcher) reconcileSentries(
ctx context.Context,
maxReadSize int,
) error {
configNodes := make([]string, 0)

Expand Down Expand Up @@ -220,7 +222,7 @@ func (w *SentryWatcher) reconcileSentries(

for _, newSentry := range newSentries {
dialer := net.Dialer{Timeout: 2 * time.Second}
s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer)
s := signer.NewReconnRemoteSigner(newSentry, w.log, w.hc, dialer, maxReadSize)

if err := s.Start(); err != nil {
return fmt.Errorf("failed to start new remote signer(s): %w", err)
Expand Down
14 changes: 10 additions & 4 deletions signer/remote_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type ReconnRemoteSigner struct {
horcruxConnection HorcruxConnection

dialer net.Dialer

maxReadSize int
}

// NewReconnRemoteSigner return a ReconnRemoteSigner that will dial using the given
Expand All @@ -43,12 +45,14 @@ func NewReconnRemoteSigner(
logger cometlog.Logger,
horcruxConnection HorcruxConnection,
dialer net.Dialer,
maxReadSize int,
) *ReconnRemoteSigner {
rs := &ReconnRemoteSigner{
address: address,
dialer: dialer,
horcruxConnection: horcruxConnection,
privKey: cometcryptoed25519.GenPrivKey(),
maxReadSize: maxReadSize,
}

rs.BaseService = *cometservice.NewBaseService(logger, "RemoteSigner", rs)
Expand Down Expand Up @@ -113,7 +117,7 @@ func (rs *ReconnRemoteSigner) loop() {
return
}

req, err := ReadMsg(conn)
req, err := ReadMsg(conn, rs.maxReadSize)
if err != nil {
rs.Logger.Error("readMsg", "err", err)
conn.Close()
Expand Down Expand Up @@ -147,9 +151,11 @@ func (rs *ReconnRemoteSigner) loop() {
}

// ReadMsg reads a message from an io.Reader
func ReadMsg(reader io.Reader) (msg cometprotoprivval.Message, err error) {
const maxRemoteSignerMsgSize = 1024 * 10
protoReader := protoio.NewDelimitedReader(reader, maxRemoteSignerMsgSize)
func ReadMsg(reader io.Reader, maxReadSize int) (msg cometprotoprivval.Message, err error) {
if maxReadSize <= 0 {
maxReadSize = 1024 * 1024 // 1MB
}
protoReader := protoio.NewDelimitedReader(reader, maxReadSize)
_, err = protoReader.ReadMsg(&msg)
return msg, err
}
Expand Down
Loading