Skip to content

Commit

Permalink
Merge pull request #73 from gatewayd-io/66-write-to-cache-async
Browse files Browse the repository at this point in the history
Update Redis cache asynhronously
  • Loading branch information
mostafa authored Dec 19, 2023
2 parents aa362ef + 726cf90 commit 99877a9
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 70 deletions.
1 change: 1 addition & 0 deletions gatewayd_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ plugins:
- API_ADDRESS=localhost:18080
- EXIT_ON_STARTUP_ERROR=False
- SENTRY_DSN=https://70eb1abcd32e41acbdfc17bc3407a543@o4504550475038720.ingest.sentry.io/4505342961123328
- CACHE_CHANNEL_BUFFER_SIZE=100
checksum: 3988e10aefce2cd9b30888eddd2ec93a431c9018a695aea1cea0dac46ba91cae
11 changes: 11 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/gatewayd-io/gatewayd-plugin-sdk/logging"
"github.com/gatewayd-io/gatewayd-plugin-sdk/metrics"
p "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin"
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
"github.com/getsentry/sentry-go"
"github.com/go-redis/redis/v8"
"github.com/hashicorp/go-hclog"
Expand Down Expand Up @@ -52,6 +53,14 @@ func main() {
go metrics.ExposeMetrics(metricsConfig, logger)
}

cacheBufferSize := cast.ToUint(cfg["cacheBufferSize"])
if cacheBufferSize <= 0 {
cacheBufferSize = 100 // default value
}

pluginInstance.Impl.UpdateCacheChannel = make(chan *v1.Struct, cacheBufferSize)
go pluginInstance.Impl.UpdateCache(context.Background())

pluginInstance.Impl.RedisURL = cast.ToString(cfg["redisURL"])
pluginInstance.Impl.Expiry = cast.ToDuration(cfg["expiry"])
pluginInstance.Impl.DefaultDBName = cast.ToString(cfg["defaultDBName"])
Expand Down Expand Up @@ -93,6 +102,8 @@ func main() {
}
}

defer close(pluginInstance.Impl.UpdateCacheChannel)

goplugin.Serve(&goplugin.ServeConfig{
HandshakeConfig: goplugin.HandshakeConfig{
ProtocolVersion: 1,
Expand Down
1 change: 1 addition & 0 deletions plugin/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ var (
"PERIODIC_INVALIDATOR_INTERVAL", "1m"),
"apiAddress": sdkConfig.GetEnv("API_ADDRESS", "localhost:8080"),
"exitOnStartupError": sdkConfig.GetEnv("EXIT_ON_STARTUP_ERROR", "false"),
"cacheBufferSize": sdkConfig.GetEnv("CACHE_CHANNEL_BUFFER_SIZE", "100"),
},
"hooks": []interface{}{
int32(v1.HookName_HOOK_NAME_ON_CLOSED),
Expand Down
146 changes: 82 additions & 64 deletions plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type Plugin struct {
ScanCount int64
ExitOnStartupError bool

UpdateCacheChannel chan *v1.Struct

// Periodic invalidator configuration.
PeriodicInvalidatorEnabled bool
PeriodicInvalidatorStartDelay time.Duration
Expand Down Expand Up @@ -144,87 +146,103 @@ func (p *Plugin) OnTrafficFromClient(
return req, nil
}

// OnTrafficFromServer is called when a response is received by GatewayD from the server.
func (p *Plugin) OnTrafficFromServer(
ctx context.Context, resp *v1.Struct,
) (*v1.Struct, error) {
OnTrafficFromServerCounter.Inc()
resp, err := postgres.HandleServerMessage(resp, p.Logger)
if err != nil {
p.Logger.Info("Failed to handle server message", "error", err)
}

rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", ""))
dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{}))
errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", ""))
request, ok := sdkPlugin.GetAttr(resp, "request", nil).([]byte)
if !ok {
request = []byte{}
}
response, ok := sdkPlugin.GetAttr(resp, "response", nil).([]byte)
if !ok {
response = []byte{}
}
server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", ""))
func (p *Plugin) UpdateCache(ctx context.Context) {
for {
serverResponse, ok := <-p.UpdateCacheChannel
if !ok {
p.Logger.Info("Channel closed, returning from function")
return
}

// This is used as a fallback if the database is not found in the startup message.
database := p.DefaultDBName
if database == "" {
client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", ""))
if client != nil && client["remote"] != "" {
database, err = p.RedisClient.Get(ctx, client["remote"]).Result()
if err != nil {
CacheMissesCounter.Inc()
p.Logger.Debug("Failed to get cached response", "error", err)
}
CacheGetsCounter.Inc()
OnTrafficFromServerCounter.Inc()
resp, err := postgres.HandleServerMessage(serverResponse, p.Logger)
if err != nil {
p.Logger.Info("Failed to handle server message", "error", err)
}
}

// If the database is still not found, return the response as is without caching.
// This might also happen if the cache is cleared while the client is still connected.
// In this case, the client should reconnect and the error will go away.
if database == "" {
p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. Skipping cache")
p.Logger.Debug("Consider setting the database name in the plugin config or disabling the plugin if you don't need it")
return resp, nil
}
rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", ""))
dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{}))
errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", ""))
request, isOk := sdkPlugin.GetAttr(resp, "request", nil).([]byte)
if !isOk {
request = []byte{}
}

cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 {
// The request was successful and the response contains data. Cache the response.
if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
CacheMissesCounter.Inc()
p.Logger.Debug("Failed to set cache", "error", err)
response, isOk := sdkPlugin.GetAttr(resp, "response", nil).([]byte)
if !isOk {
response = []byte{}
}
CacheSetsCounter.Inc()
server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", ""))

// Cache the query as well.
query, err := postgres.GetQueryFromRequest(request)
if err != nil {
p.Logger.Debug("Failed to get query from request", "error", err)
return resp, nil
// This is used as a fallback if the database is not found in the startup message.

database := p.DefaultDBName
if database == "" {
client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", ""))
if client != nil && client["remote"] != "" {
database, err = p.RedisClient.Get(ctx, client["remote"]).Result()
if err != nil {
CacheMissesCounter.Inc()
p.Logger.Debug("Failed to get cached response", "error", err)
}
CacheGetsCounter.Inc()
}
}

tables, err := postgres.GetTablesFromQuery(query)
if err != nil {
p.Logger.Debug("Failed to get tables from query", "error", err)
return resp, nil
// If the database is still not found, return the response as is without caching.
// This might also happen if the cache is cleared while the client is still connected.
// In this case, the client should reconnect and the error will go away.
if database == "" {
p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. " +
"Skipping cache")
p.Logger.Debug("Consider setting the database name in the " +
"plugin config or disabling the plugin if you don't need it")
return
}

// Cache the table(s) used in each cached request. This is used to invalidate
// the cache when a rows is inserted, updated or deleted into that table.
for _, table := range tables {
requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":")
if err := p.RedisClient.Set(
ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil {
cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":")
if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 {
// The request was successful and the response contains data. Cache the response.
if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil {
CacheMissesCounter.Inc()
p.Logger.Debug("Failed to set cache", "error", err)
}
CacheSetsCounter.Inc()

// Cache the query as well.
query, err := postgres.GetQueryFromRequest(request)
if err != nil {
p.Logger.Debug("Failed to get query from request", "error", err)
return
}

tables, err := postgres.GetTablesFromQuery(query)
if err != nil {
p.Logger.Debug("Failed to get tables from query", "error", err)
return
}

// Cache the table(s) used in each cached request. This is used to invalidate
// the cache when a rows is inserted, updated or deleted into that table.
for _, table := range tables {
requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":")
if err := p.RedisClient.Set(
ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil {
CacheMissesCounter.Inc()
p.Logger.Debug("Failed to set cache", "error", err)
}
CacheSetsCounter.Inc()
}
}
}
}

// OnTrafficFromServer is called when a response is received by GatewayD from the server.
func (p *Plugin) OnTrafficFromServer(
_ context.Context, resp *v1.Struct,
) (*v1.Struct, error) {
p.Logger.Debug("Traffic is coming from the server side")
p.UpdateCacheChannel <- resp
return resp, nil
}

Expand Down
28 changes: 22 additions & 6 deletions plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ package plugin
import (
"context"
"encoding/base64"
"os"
"testing"

miniredis "github.com/alicebob/miniredis/v2"
"github.com/gatewayd-io/gatewayd-plugin-sdk/logging"
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
"github.com/go-redis/redis/v8"
"github.com/hashicorp/go-hclog"
pgproto3 "github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
"os"
"sync"
"testing"
)

func testQueryRequest() (string, []byte) {
Expand Down Expand Up @@ -44,16 +44,28 @@ func Test_Plugin(t *testing.T) {
redisClient := redis.NewClient(redisConfig)
assert.NotNil(t, redisClient)

updateCacheChannel := make(chan *v1.Struct, 10)

// Create and initialize a new plugin.
logger := hclog.New(&hclog.LoggerOptions{
Level: logging.GetLogLevel("error"),
Output: os.Stdout,
})
p := NewCachePlugin(Plugin{
Logger: logger,
RedisURL: redisURL,
RedisClient: redisClient,
Logger: logger,
RedisURL: redisURL,
RedisClient: redisClient,
UpdateCacheChannel: updateCacheChannel,
})

// Use a WaitGroup to wait for the goroutine to finish
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.Impl.UpdateCache(context.Background())
}()

assert.NotNil(t, p)

// Test the plugin's GetPluginConfig method.
Expand Down Expand Up @@ -146,6 +158,10 @@ func Test_Plugin(t *testing.T) {
assert.NotNil(t, result)
assert.Equal(t, result, resp)

// Close the channel and wait for the cache updater to return gracefully
close(updateCacheChannel)
wg.Wait()

// Check that the query and response was cached.
cachedResponse, err := redisClient.Get(
context.Background(), "localhost:5432:postgres:"+string(request)).Bytes()
Expand Down

0 comments on commit 99877a9

Please sign in to comment.