diff --git a/crowdsec/crowdsec.go b/crowdsec/crowdsec.go index 804bdb35..3c0fab7f 100644 --- a/crowdsec/crowdsec.go +++ b/crowdsec/crowdsec.go @@ -15,6 +15,7 @@ package crowdsec import ( + "context" "errors" "fmt" "net" @@ -235,7 +236,7 @@ func (c *CrowdSec) Start() error { return err } - c.bouncer.Run() + c.bouncer.Run(context.Background()) return nil } diff --git a/internal/bouncer/bouncer.go b/internal/bouncer/bouncer.go index 524263fa..67f94c45 100644 --- a/internal/bouncer/bouncer.go +++ b/internal/bouncer/bouncer.go @@ -25,11 +25,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" csbouncer "github.com/crowdsecurity/go-cs-bouncer" - "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/sirupsen/logrus" "go.uber.org/zap" - "go.uber.org/zap/zapcore" ) const ( @@ -96,16 +93,6 @@ func New(apiKey, apiURL, tickerInterval string, logger *zap.Logger) (*Bouncer, e }, nil } -func generateInstanceID(t time.Time) (string, error) { - r := rand.New(rand.NewSource(t.Unix())) - b := [4]byte{} - if _, err := r.Read(b[:]); err != nil { - return "", err - } - - return hex.EncodeToString(b[:]), nil -} - // EnableStreaming enables usage of the StreamBouncer (instead of the LiveBouncer). func (b *Bouncer) EnableStreaming() { b.useStreamingBouncer = true @@ -118,45 +105,8 @@ func (b *Bouncer) EnableHardFails() { b.streamingBouncer.RetryInitialConnect = false } -func (b *Bouncer) zapField() zapcore.Field { - return zap.String("instance_id", b.instanceID) -} - -func (b *Bouncer) updateMetrics(m *models.RemediationComponentsMetrics, interval time.Duration) { - - m.Name = userAgentName // instance ID? Is name provided when creating bouncer in CrowdSec, it seems - m.Version = ptr.Of(userAgentVersion) - m.Type = userAgentName - m.UtcStartupTimestamp = ptr.Of(b.startedAt.UTC().Unix()) - - activeDecisions := "active_decisions" // TODO: specific values allowed? Seem to be Prometheus metrics, though - value := float64(20) // TODO: track and get actual number; per origin and type? - origin := "127.0.0.30" // TODO: bouncer IP? Or original source of decisions? - ipType := "ipv4" // TODO: IP type from bouncer? - - metric := &models.DetailedMetrics{ - Meta: &models.MetricsMeta{ - UtcNowTimestamp: ptr.Of(time.Now().Unix()), - WindowSizeSeconds: ptr.Of(int64(interval.Seconds())), - }, - Items: []*models.MetricsDetailItem{ - { - Name: ptr.Of(activeDecisions), - Value: ptr.Of(value), - Labels: map[string]string{ - "origin": origin, - "ip_type": ipType, - }, - Unit: ptr.Of("ip"), - }, - }, - } - - m.Metrics = append(m.Metrics, metric) -} - // Init initializes the Bouncer -func (b *Bouncer) Init() error { +func (b *Bouncer) Init() (err error) { // override CrowdSec's default logrus logging b.overrideLogrusLogger() @@ -168,97 +118,53 @@ func (b *Bouncer) Init() error { // initialize the CrowdSec live bouncer if !b.useStreamingBouncer { b.logger.Info("initializing live bouncer", b.zapField()) - if err := b.liveBouncer.Init(); err != nil { + if err = b.liveBouncer.Init(); err != nil { return err } - b.liveBouncer.MetricsInterval = metricsInterval - - m, err := csbouncer.NewMetricsProvider( - b.liveBouncer.APIClient, - userAgentName, - b.updateMetrics, - logrus.StandardLogger(), // TODO: move around? - ) - if err != nil { - return fmt.Errorf("failed creating metrics provider: %w", err) + if b.metricsProvider, err = newMetricsProvider(b.liveBouncer.APIClient, b.updateMetrics, metricsInterval); err != nil { + return err } - m.Interval = metricsInterval - - b.metricsProvider = m - return nil } // initialize the CrowdSec streaming bouncer b.logger.Info("initializing streaming bouncer", b.zapField()) - if err := b.streamingBouncer.Init(); err != nil { + if err = b.streamingBouncer.Init(); err != nil { return err } - b.streamingBouncer.MetricsInterval = metricsInterval - - m, err := csbouncer.NewMetricsProvider( - b.streamingBouncer.APIClient, - userAgentName, - b.updateMetrics, - logrus.StandardLogger(), // TODO: move around? - ) - if err != nil { - return fmt.Errorf("failed creating metrics provider: %w", err) + if b.metricsProvider, err = newMetricsProvider(b.streamingBouncer.APIClient, b.updateMetrics, metricsInterval); err != nil { + return err } - m.Interval = metricsInterval - - b.metricsProvider = m - return nil } // Run starts the Bouncer processes -func (b *Bouncer) Run() { +func (b *Bouncer) Run(ctx context.Context) { b.startMu.Lock() defer b.startMu.Unlock() if b.started { return } + b.wg = &sync.WaitGroup{} + b.ctx, b.cancel = context.WithCancel(ctx) + b.started = true b.startedAt = time.Now() b.logger.Info("started", b.zapField()) - b.wg = &sync.WaitGroup{} - b.ctx, b.cancel = context.WithCancel(context.Background()) - - // the LiveBouncer has nothing to run in the background; return early + // when using the live bouncer only the metrics provider needs + // to be initialized. Return early without starting other processes. if !b.useStreamingBouncer { - // TODO: deduplicate this logic; helper function? - - b.wg.Add(1) - go func() { - defer b.wg.Done() - - b.logger.Debug("starting metrics provider", b.zapField()) - if err := b.metricsProvider.Run(b.ctx); err != nil { - if err.Error() == "metric provider halted" { - b.logger.Info("metrics provider stopped", b.zapField()) - } else { - b.logger.Error("failed running metrics provider", b.zapField(), zap.Error(err)) - } - } - }() + b.startMetricsProvider(b.ctx) return } - b.wg.Add(1) - go func() { - defer b.wg.Done() - b.logger.Debug("starting streaming bouncer", b.zapField()) - b.streamingBouncer.Run(b.ctx) - }() - // TODO: close the stream nicely when the bouncer needs to quit. This is not done // in the csbouncer package itself when canceling. // TODO: wait with processing until we know we're successfully connected to @@ -266,74 +172,9 @@ func (b *Bouncer) Run() { // directly, but we could use the heartbeat service before starting to run? // That can also be useful for testing the LiveBouncer at startup. - b.wg.Add(1) - go func() { - defer b.wg.Done() - - b.logger.Debug("starting decision processing", b.zapField()) - - for { - select { - case <-b.ctx.Done(): - b.logger.Info("processing new and deleted decisions stopped", b.zapField()) - return - case decisions := <-b.streamingBouncer.Stream: - if decisions == nil { - continue - } - // TODO: deletions seem to include all old decisions that had already expired; CrowdSec bug or intended behavior? - // TODO: process in separate goroutines/waitgroup? - if numberOfDeletedDecisions := len(decisions.Deleted); numberOfDeletedDecisions > 0 { - b.logger.Debug(fmt.Sprintf("processing %d deleted decisions", numberOfDeletedDecisions), b.zapField()) - for _, decision := range decisions.Deleted { - if err := b.delete(decision); err != nil { - b.logger.Error(fmt.Sprintf("unable to delete decision for %q: %s", *decision.Value, err), b.zapField()) - } else { - if numberOfDeletedDecisions <= maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("deleted %q (scope: %s)", *decision.Value, *decision.Scope), b.zapField()) - } - } - } - if numberOfDeletedDecisions > maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("skipped logging for %d deleted decisions", numberOfDeletedDecisions), b.zapField()) - } - b.logger.Debug(fmt.Sprintf("finished processing %d deleted decisions", numberOfDeletedDecisions), b.zapField()) - } - - // TODO: process in separate goroutines/waitgroup? - if numberOfNewDecisions := len(decisions.New); numberOfNewDecisions > 0 { - b.logger.Debug(fmt.Sprintf("processing %d new decisions", numberOfNewDecisions), b.zapField()) - for _, decision := range decisions.New { - if err := b.add(decision); err != nil { - b.logger.Error(fmt.Sprintf("unable to insert decision for %q: %s", *decision.Value, err), b.zapField()) - } else { - if numberOfNewDecisions <= maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("adding %q (scope: %s) for %q", *decision.Value, *decision.Scope, *decision.Duration), b.zapField()) - } - } - } - if numberOfNewDecisions > maxNumberOfDecisionsToLog { - b.logger.Debug(fmt.Sprintf("skipped logging for %d new decisions", numberOfNewDecisions), b.zapField()) - } - b.logger.Debug(fmt.Sprintf("finished processing %d new decisions", numberOfNewDecisions), b.zapField()) - } - } - } - }() - - b.wg.Add(1) - go func() { - defer b.wg.Done() - - b.logger.Debug("starting metrics provider", b.zapField()) - if err := b.metricsProvider.Run(b.ctx); err != nil { - if err.Error() == "metric provider halted" { - b.logger.Info("metrics provider stopped", b.zapField()) - } else { - b.logger.Error("failed running metrics provider", b.zapField(), zap.Error(err)) - } - } - }() + b.startStreamingBouncer(b.ctx) + b.startProcessingDecisions(b.ctx) + b.startMetricsProvider(b.ctx) } // Shutdown stops the Bouncer @@ -344,44 +185,19 @@ func (b *Bouncer) Shutdown() error { return nil } - b.logger.Info("stopping", b.zapField()) - defer func() { - b.stopped = true - b.logger.Info("finished", b.zapField()) - b.logger.Sync() // nolint - }() - - // TODO: verify this is OK - // // the LiveBouncer has nothing to do on shutdown - // if !b.useStreamingBouncer { - // return nil - // } + b.logger.Info("stopping ...", b.zapField()) b.cancel() b.wg.Wait() // TODO: clean shutdown of the streaming bouncer channel reading //b.store = nil // TODO(hs): setting this to nil without reinstantiating it, leads to errors; do this properly. - return nil -} - -// Add adds a Decision to the storage -func (b *Bouncer) add(decision *models.Decision) error { - // TODO: provide additional ways for storing the decisions - // (i.e. radix tree is not always the most efficient one, but it's great for matching IPs to ranges) - // Knowing that a key is a CIDR does allow to check an IP with the .Contains() function, but still - // requires looping through the ranges - - // TODO: store additional data about the decision (i.e. time added to store, etc) - // TODO: wrap the *models.Decision in an internal model (after validation)? - - return b.store.add(decision) -} + b.stopped = true + b.logger.Info("finished", b.zapField()) + b.logger.Sync() // nolint -// Delete removes a Decision from the storage -func (b *Bouncer) delete(decision *models.Decision) error { - return b.store.delete(decision) + return nil } // IsAllowed checks if an IP is allowed or not @@ -404,29 +220,12 @@ func (b *Bouncer) IsAllowed(ip net.IP) (bool, *models.Decision, error) { return isAllowed, nil, nil } -func (b *Bouncer) retrieveDecision(ip net.IP) (*models.Decision, error) { - if b.useStreamingBouncer { - return b.store.get(ip) - } - - decision, err := b.liveBouncer.Get(ip.String()) - if err != nil { - fields := []zapcore.Field{ - b.zapField(), - zap.String("address", b.liveBouncer.APIUrl), - zap.Error(err), - } - if b.shouldFailHard { - b.logger.Fatal(err.Error(), fields...) - } else { - b.logger.Error(err.Error(), fields...) - } - return nil, nil // when not failing hard, we return no error - } - - if len(*decision) >= 1 { - return (*decision)[0], nil // TODO: decide if choosing the first decision is OK +func generateInstanceID(t time.Time) (string, error) { + r := rand.New(rand.NewSource(t.Unix())) + b := [4]byte{} + if _, err := r.Read(b[:]); err != nil { + return "", err } - return nil, nil + return hex.EncodeToString(b[:]), nil } diff --git a/internal/bouncer/bouncer_test.go b/internal/bouncer/bouncer_test.go index 015de3ba..f5ec2bc3 100644 --- a/internal/bouncer/bouncer_test.go +++ b/internal/bouncer/bouncer_test.go @@ -1,6 +1,7 @@ package bouncer import ( + "context" "fmt" "net" "net/url" @@ -147,7 +148,7 @@ func TestStreamingBouncer(t *testing.T) { // run the bouncer; makes it make a call to the mocked CrowdSec API // this should be called after the httpmock is activated, because otherwise the bouncer // will try to call an actual CrowdSec instance - b.Run() + b.Run(context.Background()) // allow the bouncer a bit of time to retrieve and store the mocked rules time.Sleep(1 * time.Second) diff --git a/internal/bouncer/decisions.go b/internal/bouncer/decisions.go new file mode 100644 index 00000000..74cb8722 --- /dev/null +++ b/internal/bouncer/decisions.go @@ -0,0 +1,123 @@ +package bouncer + +import ( + "context" + "fmt" + "net" + + "github.com/crowdsecurity/crowdsec/pkg/models" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func (b *Bouncer) startStreamingBouncer(ctx context.Context) { + b.wg.Add(1) + go func() { + defer b.wg.Done() + b.logger.Debug("starting streaming bouncer", b.zapField()) + b.streamingBouncer.Run(ctx) + }() +} + +func (b *Bouncer) startProcessingDecisions(ctx context.Context) { + b.wg.Add(1) + go func() { + defer b.wg.Done() + + b.logger.Debug("starting decision processing", b.zapField()) + + for { + select { + case <-b.ctx.Done(): + b.logger.Info("processing new and deleted decisions stopped", b.zapField()) + return + case decisions := <-b.streamingBouncer.Stream: + if decisions == nil { + continue + } + // TODO: deletions seem to include all old decisions that had already expired; CrowdSec bug or intended behavior? + // TODO: process in separate goroutines/waitgroup? + if numberOfDeletedDecisions := len(decisions.Deleted); numberOfDeletedDecisions > 0 { + b.logger.Debug(fmt.Sprintf("processing %d deleted decisions", numberOfDeletedDecisions), b.zapField()) + for _, decision := range decisions.Deleted { + if err := b.delete(decision); err != nil { + b.logger.Error(fmt.Sprintf("unable to delete decision for %q: %s", *decision.Value, err), b.zapField()) + } else { + if numberOfDeletedDecisions <= maxNumberOfDecisionsToLog { + b.logger.Debug(fmt.Sprintf("deleted %q (scope: %s)", *decision.Value, *decision.Scope), b.zapField()) + } + } + } + if numberOfDeletedDecisions > maxNumberOfDecisionsToLog { + b.logger.Debug(fmt.Sprintf("skipped logging for %d deleted decisions", numberOfDeletedDecisions), b.zapField()) + } + b.logger.Debug(fmt.Sprintf("finished processing %d deleted decisions", numberOfDeletedDecisions), b.zapField()) + } + + // TODO: process in separate goroutines/waitgroup? + if numberOfNewDecisions := len(decisions.New); numberOfNewDecisions > 0 { + b.logger.Debug(fmt.Sprintf("processing %d new decisions", numberOfNewDecisions), b.zapField()) + for _, decision := range decisions.New { + if err := b.add(decision); err != nil { + b.logger.Error(fmt.Sprintf("unable to insert decision for %q: %s", *decision.Value, err), b.zapField()) + } else { + if numberOfNewDecisions <= maxNumberOfDecisionsToLog { + b.logger.Debug(fmt.Sprintf("adding %q (scope: %s) for %q", *decision.Value, *decision.Scope, *decision.Duration), b.zapField()) + } + } + } + if numberOfNewDecisions > maxNumberOfDecisionsToLog { + b.logger.Debug(fmt.Sprintf("skipped logging for %d new decisions", numberOfNewDecisions), b.zapField()) + } + b.logger.Debug(fmt.Sprintf("finished processing %d new decisions", numberOfNewDecisions), b.zapField()) + } + } + } + }() +} + +// Add adds a Decision to the storage +func (b *Bouncer) add(decision *models.Decision) error { + + // TODO: provide additional ways for storing the decisions + // (i.e. radix tree is not always the most efficient one, but it's great for matching IPs to ranges) + // Knowing that a key is a CIDR does allow to check an IP with the .Contains() function, but still + // requires looping through the ranges + + // TODO: store additional data about the decision (i.e. time added to store, etc) + // TODO: wrap the *models.Decision in an internal model (after validation)? + + return b.store.add(decision) +} + +// Delete removes a Decision from the storage +func (b *Bouncer) delete(decision *models.Decision) error { + return b.store.delete(decision) +} + +func (b *Bouncer) retrieveDecision(ip net.IP) (*models.Decision, error) { + if b.useStreamingBouncer { + return b.store.get(ip) + } + + decision, err := b.liveBouncer.Get(ip.String()) + if err != nil { + fields := []zapcore.Field{ + b.zapField(), + zap.String("address", b.liveBouncer.APIUrl), + zap.Error(err), + } + if b.shouldFailHard { + b.logger.Fatal(err.Error(), fields...) + } else { + b.logger.Error(err.Error(), fields...) + } + return nil, nil // when not failing hard, we return no error + } + + if len(*decision) >= 1 { + return (*decision)[0], nil // TODO: decide if choosing the first decision is OK + } + + return nil, nil +} diff --git a/internal/bouncer/logging.go b/internal/bouncer/logging.go index a28e9255..fdbf81f2 100644 --- a/internal/bouncer/logging.go +++ b/internal/bouncer/logging.go @@ -38,6 +38,10 @@ func (b *Bouncer) overrideLogrusLogger() { std.ReplaceHooks(hooks) } +func (b *Bouncer) zapField() zapcore.Field { + return zap.String("instance_id", b.instanceID) +} + type zapAdapterHook struct { logger *zap.Logger shouldFailHard bool @@ -106,3 +110,7 @@ var levelAdapter = map[logrus.Level]zapcore.Level{ } var _ logrus.Hook = (*zapAdapterHook)(nil) + +func newMetricsLogger() *logrus.Logger { + return logrus.StandardLogger() +} diff --git a/internal/bouncer/metrics.go b/internal/bouncer/metrics.go new file mode 100644 index 00000000..8ed72f7b --- /dev/null +++ b/internal/bouncer/metrics.go @@ -0,0 +1,54 @@ +package bouncer + +import ( + "context" + "fmt" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/models" + csbouncer "github.com/crowdsecurity/go-cs-bouncer" + "github.com/crowdsecurity/go-cs-lib/ptr" + "go.uber.org/zap" +) + +func newMetricsProvider(client *apiclient.ApiClient, updater csbouncer.MetricsUpdater, interval time.Duration) (*csbouncer.MetricsProvider, error) { + m, err := csbouncer.NewMetricsProvider( + client, + userAgentName, + updater, + newMetricsLogger(), + ) + if err != nil { + return nil, fmt.Errorf("failed creating metrics provider: %w", err) + } + + m.Interval = interval + + return m, nil +} + +func (b *Bouncer) startMetricsProvider(ctx context.Context) { + b.wg.Add(1) + go func() { + defer b.wg.Done() + + b.logger.Debug("starting metrics provider", b.zapField()) + if err := b.metricsProvider.Run(ctx); err != nil { + if err.Error() == "metric provider halted" { + b.logger.Info("metrics provider stopped", b.zapField()) + } else { + b.logger.Error("failed running metrics provider", b.zapField(), zap.Error(err)) + } + } + }() +} + +func (b *Bouncer) updateMetrics(m *models.RemediationComponentsMetrics, interval time.Duration) { + m.Name = userAgentName // instance ID? Is name provided when creating bouncer in CrowdSec, it seems + m.Version = ptr.Of(userAgentVersion) + m.Type = userAgentName + m.UtcStartupTimestamp = ptr.Of(b.startedAt.UTC().Unix()) + + // TODO: add metrics +}