diff --git a/automod/engine.go b/automod/engine.go index 4a3509ff9..fe7bc26dc 100644 --- a/automod/engine.go +++ b/automod/engine.go @@ -22,6 +22,7 @@ type Engine struct { Counters CountStore Sets SetStore Cache CacheStore + Flags FlagStore RelayClient *xrpc.Client BskyClient *xrpc.Client // used to persist moderation actions in mod service (optional) diff --git a/automod/engine_test.go b/automod/engine_test.go index daa68d93a..25994cd9a 100644 --- a/automod/engine_test.go +++ b/automod/engine_test.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "testing" + "time" appbsky "github.com/bluesky-social/indigo/api/bsky" "github.com/bluesky-social/indigo/atproto/identity" @@ -39,6 +40,8 @@ func engineFixture() Engine { simpleRule, }, } + cache := NewMemCacheStore(10, time.Hour) + flags := NewMemFlagStore() sets := NewMemSetStore() sets.Sets["bad-hashtags"] = make(map[string]bool) sets.Sets["bad-hashtags"]["slur"] = true @@ -53,6 +56,8 @@ func engineFixture() Engine { Directory: &dir, Counters: NewMemCountStore(), Sets: sets, + Flags: flags, + Cache: cache, Rules: rules, } return engine diff --git a/automod/flagstore.go b/automod/flagstore.go new file mode 100644 index 000000000..b61402781 --- /dev/null +++ b/automod/flagstore.go @@ -0,0 +1,66 @@ +package automod + +import ( + "context" +) + +type FlagStore interface { + Get(ctx context.Context, key string) ([]string, error) + Add(ctx context.Context, key string, flags []string) error + Remove(ctx context.Context, key string, flags []string) error +} + +type MemFlagStore struct { + Data map[string][]string +} + +func NewMemFlagStore() MemFlagStore { + return MemFlagStore{ + Data: make(map[string][]string), + } +} + +func (s MemFlagStore) Get(ctx context.Context, key string) ([]string, error) { + v, ok := s.Data[key] + if !ok { + return []string{}, nil + } + return v, nil +} + +func (s MemFlagStore) Add(ctx context.Context, key string, flags []string) error { + v, ok := s.Data[key] + if !ok { + v = []string{} + } + for _, f := range flags { + v = append(v, f) + } + v = dedupeStrings(v) + s.Data[key] = v + return nil +} + +// does not error if flags not in set +func (s MemFlagStore) Remove(ctx context.Context, key string, flags []string) error { + if len(flags) == 0 { + return nil + } + v, ok := s.Data[key] + if !ok { + v = []string{} + } + m := make(map[string]bool, len(v)) + for _, f := range v { + m[f] = true + } + for _, f := range flags { + delete(m, f) + } + out := []string{} + for f, _ := range m { + out = append(out, f) + } + s.Data[key] = out + return nil +} diff --git a/automod/flagstore_test.go b/automod/flagstore_test.go new file mode 100644 index 000000000..a64ac67df --- /dev/null +++ b/automod/flagstore_test.go @@ -0,0 +1,30 @@ +package automod + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFlagStoreBasics(t *testing.T) { + assert := assert.New(t) + ctx := context.Background() + + fs := NewMemFlagStore() + + l, err := fs.Get(ctx, "test1") + assert.NoError(err) + assert.Empty(l) + + assert.NoError(fs.Add(ctx, "test1", []string{"red", "green"})) + assert.NoError(fs.Add(ctx, "test1", []string{"red", "blue"})) + l, err = fs.Get(ctx, "test1") + assert.NoError(err) + assert.Equal(3, len(l)) + + assert.NoError(fs.Remove(ctx, "test1", []string{"red", "blue"})) + l, err = fs.Get(ctx, "test1") + assert.NoError(err) + assert.Equal([]string{"green"}, l) +} diff --git a/automod/redis_flags.go b/automod/redis_flags.go new file mode 100644 index 000000000..83a37e68d --- /dev/null +++ b/automod/redis_flags.go @@ -0,0 +1,65 @@ +package automod + +import ( + "context" + + "github.com/redis/go-redis/v9" +) + +var redisFlagsPrefix string = "flags/" + +type RedisFlagStore struct { + Client *redis.Client +} + +func NewRedisFlagStore(redisURL string) (*RedisFlagStore, error) { + opt, err := redis.ParseURL(redisURL) + if err != nil { + return nil, err + } + rdb := redis.NewClient(opt) + // check redis connection + _, err = rdb.Ping(context.TODO()).Result() + if err != nil { + return nil, err + } + rcs := RedisFlagStore{ + Client: rdb, + } + return &rcs, nil +} + +func (s *RedisFlagStore) Get(ctx context.Context, key string) ([]string, error) { + rkey := redisFlagsPrefix + key + l, err := s.Client.SMembers(ctx, rkey).Result() + if err == redis.Nil { + return []string{}, nil + } else if err != nil { + return nil, err + } + return l, nil +} + +func (s *RedisFlagStore) Add(ctx context.Context, key string, flags []string) error { + if len(flags) == 0 { + return nil + } + l := []interface{}{} + for _, v := range flags { + l = append(l, v) + } + rkey := redisFlagsPrefix + key + return s.Client.SAdd(ctx, rkey, l...).Err() +} + +func (s *RedisFlagStore) Remove(ctx context.Context, key string, flags []string) error { + if len(flags) == 0 { + return nil + } + l := []interface{}{} + for _, v := range flags { + l = append(l, v) + } + rkey := redisFlagsPrefix + key + return s.Client.SRem(ctx, rkey, l...).Err() +} diff --git a/automod/redis_flagstore_test.go b/automod/redis_flagstore_test.go new file mode 100644 index 000000000..a295baa4f --- /dev/null +++ b/automod/redis_flagstore_test.go @@ -0,0 +1,35 @@ +package automod + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRedisFlagStoreBasics(t *testing.T) { + t.Skip("live test, need redis running locally") + assert := assert.New(t) + ctx := context.Background() + + fs, err := NewRedisFlagStore("redis://localhost:6379/0") + if err != nil { + t.Fail() + } + + l, err := fs.Get(ctx, "test1") + assert.NoError(err) + assert.Empty(l) + + assert.NoError(fs.Add(ctx, "test1", []string{"red", "green"})) + assert.NoError(fs.Add(ctx, "test1", []string{"red", "blue"})) + l, err = fs.Get(ctx, "test1") + assert.NoError(err) + assert.Equal(3, len(l)) + + assert.NoError(fs.Remove(ctx, "test1", []string{"red", "blue", "orange"})) + l, err = fs.Get(ctx, "test1") + assert.NoError(err) + assert.Equal([]string{"green"}, l) + assert.NoError(fs.Remove(ctx, "test1", []string{"green"})) +} diff --git a/automod/rules/fixture_test.go b/automod/rules/fixture_test.go index a97ebbb99..d092c8406 100644 --- a/automod/rules/fixture_test.go +++ b/automod/rules/fixture_test.go @@ -2,6 +2,7 @@ package rules import ( "log/slog" + "time" "github.com/bluesky-social/indigo/atproto/identity" "github.com/bluesky-social/indigo/atproto/syntax" @@ -15,6 +16,8 @@ func engineFixture() automod.Engine { BadHashtagsPostRule, }, } + flags := automod.NewMemFlagStore() + cache := automod.NewMemCacheStore(10, time.Hour) sets := automod.NewMemSetStore() sets.Sets["bad-hashtags"] = make(map[string]bool) sets.Sets["bad-hashtags"]["slur"] = true @@ -37,6 +40,8 @@ func engineFixture() automod.Engine { Directory: &dir, Counters: automod.NewMemCountStore(), Sets: sets, + Flags: flags, + Cache: cache, Rules: rules, AdminClient: &adminc, } diff --git a/cmd/hepa/server.go b/cmd/hepa/server.go index f8954b667..560b85fe7 100644 --- a/cmd/hepa/server.go +++ b/cmd/hepa/server.go @@ -88,6 +88,7 @@ func NewServer(dir identity.Directory, config Config) (*Server, error) { var counters automod.CountStore var cache automod.CacheStore + var flags automod.FlagStore var rdb *redis.Client if config.RedisURL != "" { // generic client, for cursor state @@ -113,9 +114,16 @@ func NewServer(dir identity.Directory, config Config) (*Server, error) { return nil, err } cache = csh + + flg, err := automod.NewRedisFlagStore(config.RedisURL) + if err != nil { + return nil, err + } + flags = flg } else { counters = automod.NewMemCountStore() cache = automod.NewMemCacheStore(5_000, 30*time.Minute) + flags = automod.NewMemFlagStore() } engine := automod.Engine{ @@ -123,6 +131,7 @@ func NewServer(dir identity.Directory, config Config) (*Server, error) { Directory: dir, Counters: counters, Sets: sets, + Flags: flags, Cache: cache, Rules: rules.DefaultRules(), AdminClient: xrpcc,