From e139f38c1f350bf9f28f05f258439429b191722d Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Wed, 28 Aug 2024 23:37:55 +0800 Subject: [PATCH] feat: service support persistence --- internal/core/dify_invocation/types.go | 26 ++++++ internal/core/persistence/init.go | 14 +++- internal/core/persistence/persistence.go | 22 ++--- internal/core/persistence/persistence_test.go | 16 ++-- .../backwards_invocation/task.go | 80 +++++++++++++++++-- internal/core/session_manager/session.go | 12 ++- internal/server/app.go | 4 - internal/server/server.go | 2 +- internal/service/endpoint.go | 14 +++- internal/service/invoke_model.go | 12 ++- internal/service/invoke_tool.go | 61 +++++++++++--- 11 files changed, 212 insertions(+), 51 deletions(-) diff --git a/internal/core/dify_invocation/types.go b/internal/core/dify_invocation/types.go index a282441..f80f4a4 100644 --- a/internal/core/dify_invocation/types.go +++ b/internal/core/dify_invocation/types.go @@ -1,8 +1,10 @@ package dify_invocation import ( + "github.com/go-playground/validator/v10" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/app_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" ) type BaseInvokeDifyRequest struct { @@ -23,6 +25,7 @@ const ( INVOKE_TYPE_TOOL InvokeType = "tool" INVOKE_TYPE_NODE InvokeType = "node" INVOKE_TYPE_APP InvokeType = "app" + INVOKE_TYPE_STORAGE InvokeType = "storage" ) type InvokeLLMRequest struct { @@ -71,6 +74,29 @@ type InvokeAppSchema struct { Files []*app_entities.FileVar `json:"files" validate:"omitempty,dive"` } +type StorageOpt string + +const ( + STORAGE_OPT_GET StorageOpt = "get" + STORAGE_OPT_SET StorageOpt = "set" + STORAGE_OPT_DEL StorageOpt = "del" +) + +func isStorageOpt(fl validator.FieldLevel) bool { + opt := StorageOpt(fl.Field().String()) + return opt == STORAGE_OPT_GET || opt == STORAGE_OPT_SET || opt == STORAGE_OPT_DEL +} + +func init() { + validators.GlobalEntitiesValidator.RegisterValidation("storage_opt", isStorageOpt) +} + +type InvokeStorageRequest struct { + Opt StorageOpt `json:"opt" validate:"required,storage_opt"` + Key string `json:"key" validate:"required"` + Value string `json:"value"` // encoded in hex, optional +} + type InvokeAppRequest struct { BaseInvokeDifyRequest requests.BaseRequestInvokeModel diff --git a/internal/core/persistence/init.go b/internal/core/persistence/init.go index cae7d75..cad15a7 100644 --- a/internal/core/persistence/init.go +++ b/internal/core/persistence/init.go @@ -5,7 +5,11 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/utils/log" ) -func InitPersistence(config *app.Config) *Persistence { +var ( + persistence *Persistence +) + +func InitPersistence(config *app.Config) { if config.PersistenceStorageType == "s3" { s3, err := NewS3Wrapper( config.PersistenceStorageS3Region, @@ -17,16 +21,18 @@ func InitPersistence(config *app.Config) *Persistence { log.Panic("Failed to initialize S3 wrapper: %v", err) } - return &Persistence{ + persistence = &Persistence{ storage: s3, } } else if config.PersistenceStorageType == "local" { - return &Persistence{ + persistence = &Persistence{ storage: NewLocalWrapper(config.PersistenceStorageLocalPath), } } else { log.Panic("Invalid persistence storage type: %s", config.PersistenceStorageType) } +} - return nil +func GetPersistence() *Persistence { + return persistence } diff --git a/internal/core/persistence/persistence.go b/internal/core/persistence/persistence.go index c90890c..3f1677b 100644 --- a/internal/core/persistence/persistence.go +++ b/internal/core/persistence/persistence.go @@ -16,21 +16,21 @@ const ( CACHE_KEY_PREFIX = "persistence:cache" ) -func (c *Persistence) getCacheKey(tenant_id string, plugin_checksum string, key string) string { - return fmt.Sprintf("%s:%s:%s:%s", CACHE_KEY_PREFIX, tenant_id, plugin_checksum, key) +func (c *Persistence) getCacheKey(tenant_id string, plugin_identity string, key string) string { + return fmt.Sprintf("%s:%s:%s:%s", CACHE_KEY_PREFIX, tenant_id, plugin_identity, key) } -func (c *Persistence) Save(tenant_id string, plugin_checksum string, key string, data []byte) error { +func (c *Persistence) Save(tenant_id string, plugin_identity string, key string, data []byte) error { if len(key) > 64 { return fmt.Errorf("key length must be less than 64 characters") } - return c.storage.Save(tenant_id, plugin_checksum, key, data) + return c.storage.Save(tenant_id, plugin_identity, key, data) } -func (c *Persistence) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) { +func (c *Persistence) Load(tenant_id string, plugin_identity string, key string) ([]byte, error) { // check if the key exists in cache - h, err := cache.GetString(c.getCacheKey(tenant_id, plugin_checksum, key)) + h, err := cache.GetString(c.getCacheKey(tenant_id, plugin_identity, key)) if err != nil && err != cache.ErrNotFound { return nil, err } @@ -39,22 +39,22 @@ func (c *Persistence) Load(tenant_id string, plugin_checksum string, key string) } // load from storage - data, err := c.storage.Load(tenant_id, plugin_checksum, key) + data, err := c.storage.Load(tenant_id, plugin_identity, key) if err != nil { return nil, err } // add to cache - cache.Store(c.getCacheKey(tenant_id, plugin_checksum, key), hex.EncodeToString(data), time.Minute*5) + cache.Store(c.getCacheKey(tenant_id, plugin_identity, key), hex.EncodeToString(data), time.Minute*5) return data, nil } -func (c *Persistence) Delete(tenant_id string, plugin_checksum string, key string) error { +func (c *Persistence) Delete(tenant_id string, plugin_identity string, key string) error { // delete from cache and storage - err := cache.Del(c.getCacheKey(tenant_id, plugin_checksum, key)) + err := cache.Del(c.getCacheKey(tenant_id, plugin_identity, key)) if err != nil { return err } - return c.storage.Delete(tenant_id, plugin_checksum, key) + return c.storage.Delete(tenant_id, plugin_identity, key) } diff --git a/internal/core/persistence/persistence_test.go b/internal/core/persistence/persistence_test.go index e08bea3..522eba4 100644 --- a/internal/core/persistence/persistence_test.go +++ b/internal/core/persistence/persistence_test.go @@ -17,18 +17,18 @@ func TestPersistenceStoreAndLoad(t *testing.T) { } defer cache.Close() - p := InitPersistence(&app.Config{ + InitPersistence(&app.Config{ PersistenceStorageType: "local", PersistenceStorageLocalPath: "./persistence_storage", }) key := strings.RandomString(10) - if err := p.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil { + if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil { t.Fatalf("Failed to save data: %v", err) } - data, err := p.Load("tenant_id", "plugin_checksum", key) + data, err := persistence.Load("tenant_id", "plugin_checksum", key) if err != nil { t.Fatalf("Failed to load data: %v", err) } @@ -65,14 +65,14 @@ func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) { } defer cache.Close() - p := InitPersistence(&app.Config{ + InitPersistence(&app.Config{ PersistenceStorageType: "local", PersistenceStorageLocalPath: "./persistence_storage", }) key := strings.RandomString(65) - if err := p.Save("tenant_id", "plugin_checksum", key, []byte("data")); err == nil { + if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err == nil { t.Fatalf("Expected error, got nil") } } @@ -84,18 +84,18 @@ func TestPersistenceDelete(t *testing.T) { } defer cache.Close() - p := InitPersistence(&app.Config{ + InitPersistence(&app.Config{ PersistenceStorageType: "local", PersistenceStorageLocalPath: "./persistence_storage", }) key := strings.RandomString(10) - if err := p.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil { + if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil { t.Fatalf("Failed to save data: %v", err) } - if err := p.Delete("tenant_id", "plugin_checksum", key); err != nil { + if err := persistence.Delete("tenant_id", "plugin_checksum", key); err != nil { t.Fatalf("Failed to delete data: %v", err) } diff --git a/internal/core/plugin_daemon/backwards_invocation/task.go b/internal/core/plugin_daemon/backwards_invocation/task.go index ddd5145..a176aa7 100644 --- a/internal/core/plugin_daemon/backwards_invocation/task.go +++ b/internal/core/plugin_daemon/backwards_invocation/task.go @@ -1,6 +1,7 @@ package backwards_invocation import ( + "encoding/hex" "fmt" "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation" @@ -169,28 +170,31 @@ func prepareDifyInvocationArguments( var ( dispatchMapping = map[dify_invocation.InvokeType]func(handle *BackwardsInvocation){ dify_invocation.INVOKE_TYPE_TOOL: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeToolRequest](handle, executeDifyInvocationToolTask) + genericDispatchTask(handle, executeDifyInvocationToolTask) }, dify_invocation.INVOKE_TYPE_LLM: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeLLMRequest](handle, executeDifyInvocationLLMTask) + genericDispatchTask(handle, executeDifyInvocationLLMTask) }, dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeTextEmbeddingRequest](handle, executeDifyInvocationTextEmbeddingTask) + genericDispatchTask(handle, executeDifyInvocationTextEmbeddingTask) }, dify_invocation.INVOKE_TYPE_RERANK: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeRerankRequest](handle, executeDifyInvocationRerankTask) + genericDispatchTask(handle, executeDifyInvocationRerankTask) }, dify_invocation.INVOKE_TYPE_TTS: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeTTSRequest](handle, executeDifyInvocationTTSTask) + genericDispatchTask(handle, executeDifyInvocationTTSTask) }, dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeSpeech2TextRequest](handle, executeDifyInvocationSpeech2TextTask) + genericDispatchTask(handle, executeDifyInvocationSpeech2TextTask) }, dify_invocation.INVOKE_TYPE_MODERATION: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeModerationRequest](handle, executeDifyInvocationModerationTask) + genericDispatchTask(handle, executeDifyInvocationModerationTask) }, dify_invocation.INVOKE_TYPE_APP: func(handle *BackwardsInvocation) { - genericDispatchTask[dify_invocation.InvokeAppRequest](handle, executeDifyInvocationAppTask) + genericDispatchTask(handle, executeDifyInvocationAppTask) + }, + dify_invocation.INVOKE_TYPE_STORAGE: func(handle *BackwardsInvocation) { + genericDispatchTask(handle, executeDifyInvocationStorageTask) }, } ) @@ -356,3 +360,63 @@ func executeDifyInvocationAppTask( handle.WriteResponse("stream", t) }) } + +func executeDifyInvocationStorageTask( + handle *BackwardsInvocation, + request *dify_invocation.InvokeStorageRequest, +) { + if handle.session == nil { + handle.WriteError(fmt.Errorf("session not found")) + return + } + + persistence := handle.session.Storage() + if persistence == nil { + handle.WriteError(fmt.Errorf("persistence not found")) + return + } + + tenant_id, err := handle.TenantID() + if err != nil { + handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error())) + return + } + + plugin_id := handle.session.PluginIdentity + + if request.Opt == dify_invocation.STORAGE_OPT_GET { + data, err := persistence.Load(tenant_id, plugin_id, request.Key) + if err != nil { + handle.WriteError(fmt.Errorf("load data failed: %s", err.Error())) + return + } + + handle.WriteResponse("struct", map[string]any{ + "data": hex.EncodeToString(data), + }) + } else if request.Opt == dify_invocation.STORAGE_OPT_SET { + data, err := hex.DecodeString(request.Value) + if err != nil { + handle.WriteError(fmt.Errorf("decode data failed: %s", err.Error())) + return + } + + if err := persistence.Save(tenant_id, plugin_id, request.Key, data); err != nil { + handle.WriteError(fmt.Errorf("save data failed: %s", err.Error())) + return + } + + handle.WriteResponse("struct", map[string]any{ + "data": "ok", + }) + } else if request.Opt == dify_invocation.STORAGE_OPT_DEL { + if err := persistence.Delete(tenant_id, plugin_id, request.Key); err != nil { + handle.WriteError(fmt.Errorf("delete data failed: %s", err.Error())) + return + } + + handle.WriteResponse("struct", map[string]any{ + "data": "ok", + }) + } +} diff --git a/internal/core/session_manager/session.go b/internal/core/session_manager/session.go index 9424788..3b59c27 100644 --- a/internal/core/session_manager/session.go +++ b/internal/core/session_manager/session.go @@ -7,6 +7,7 @@ import ( "time" "github.com/google/uuid" + "github.com/langgenius/dify-plugin-daemon/internal/core/persistence" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/utils/cache" @@ -21,8 +22,9 @@ var ( // session need to implement the backwards_invocation.BackwardsInvocationWriter interface type Session struct { - ID string `json:"id"` - runtime plugin_entities.PluginRuntimeInterface `json:"-"` + ID string `json:"id"` + runtime plugin_entities.PluginRuntimeInterface `json:"-"` + persistence *persistence.Persistence `json:"-"` TenantID string `json:"tenant_id"` UserID string `json:"user_id"` @@ -45,6 +47,7 @@ func NewSession( invoke_from access_types.PluginAccessType, action access_types.PluginAccessAction, declaration *plugin_entities.PluginDeclaration, + persistence *persistence.Persistence, ) *Session { s := &Session{ ID: uuid.New().String(), @@ -55,6 +58,7 @@ func NewSession( InvokeFrom: invoke_from, Action: action, Declaration: declaration, + persistence: persistence, } session_lock.Lock() @@ -97,6 +101,10 @@ func (s *Session) Runtime() plugin_entities.PluginRuntimeInterface { return s.runtime } +func (s *Session) Storage() *persistence.Persistence { + return s.persistence +} + type PLUGIN_IN_STREAM_EVENT string const ( diff --git a/internal/server/app.go b/internal/server/app.go index 42ee369..8ebb46c 100644 --- a/internal/server/app.go +++ b/internal/server/app.go @@ -2,7 +2,6 @@ package server import ( "github.com/langgenius/dify-plugin-daemon/internal/cluster" - "github.com/langgenius/dify-plugin-daemon/internal/core/persistence" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction" ) @@ -18,7 +17,4 @@ type App struct { // aws transaction handler // accept aws transaction request and forward to the plugin daemon aws_transaction_handler *transaction.AWSTransactionHandler - - // persistence - persistence *persistence.Persistence } diff --git a/internal/server/server.go b/internal/server/server.go index fbb7a58..fdcecfa 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -26,7 +26,7 @@ func (a *App) Run(config *app.Config) { plugin_manager.InitGlobalPluginManager(a.cluster, config) // init persistence - a.persistence = persistence.InitPersistence(config) + persistence.InitPersistence(config) // launch cluster a.cluster.Launch() diff --git a/internal/service/endpoint.go b/internal/service/endpoint.go index 6e6c5a3..ba4faf9 100644 --- a/internal/service/endpoint.go +++ b/internal/service/endpoint.go @@ -8,6 +8,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/langgenius/dify-plugin-daemon/internal/core/persistence" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" @@ -17,7 +18,11 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" ) -func Endpoint(ctx *gin.Context, endpoint *models.Endpoint, path string) { +func Endpoint( + ctx *gin.Context, + endpoint *models.Endpoint, + path string, +) { req := ctx.Request.Clone(context.Background()) req.URL.Path = path @@ -36,6 +41,12 @@ func Endpoint(ctx *gin.Context, endpoint *models.Endpoint, path string) { return } + persistence := persistence.GetPersistence() + if persistence == nil { + ctx.JSON(500, gin.H{"error": "persistence not found"}) + return + } + session := session_manager.NewSession( endpoint.TenantID, "", @@ -44,6 +55,7 @@ func Endpoint(ctx *gin.Context, endpoint *models.Endpoint, path string) { access_types.PLUGIN_ACCESS_TYPE_Endpoint, access_types.PLUGIN_ACCESS_ACTION_INVOKE_ENDPOINT, runtime.Configuration(), + persistence, ) defer session.Close() diff --git a/internal/service/invoke_model.go b/internal/service/invoke_model.go index 7ab691f..d1a409b 100644 --- a/internal/service/invoke_model.go +++ b/internal/service/invoke_model.go @@ -16,12 +16,16 @@ func InvokeTool( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_TOOL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_TOOL, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -39,12 +43,16 @@ func ValidateToolCredentials( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_TOOL, access_types.PLUGIN_ACCESS_ACTION_VALIDATE_TOOL_CREDENTIALS, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( diff --git a/internal/service/invoke_tool.go b/internal/service/invoke_tool.go index d66a00a..ccfbcc2 100644 --- a/internal/service/invoke_tool.go +++ b/internal/service/invoke_tool.go @@ -1,7 +1,10 @@ package service import ( + "errors" + "github.com/gin-gonic/gin" + "github.com/langgenius/dify-plugin-daemon/internal/core/persistence" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" @@ -18,7 +21,12 @@ func createSession[T any]( access_type access_types.PluginAccessType, access_action access_types.PluginAccessAction, cluster_id string, -) *session_manager.Session { +) (*session_manager.Session, error) { + persistence := persistence.GetPersistence() + if persistence == nil { + return nil, errors.New("persistence not found") + } + plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion) runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity) @@ -30,10 +38,11 @@ func createSession[T any]( access_type, access_action, runtime.Configuration(), + persistence, ) session.BindRuntime(runtime) - return session + return session, nil } func InvokeLLM( @@ -42,12 +51,16 @@ func InvokeLLM( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -65,11 +78,15 @@ func InvokeTextEmbedding( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING, ctx.GetString("cluster_id")) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -87,12 +104,16 @@ func InvokeRerank( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -110,12 +131,16 @@ func InvokeTTS( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -133,12 +158,16 @@ func InvokeSpeech2Text( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -156,12 +185,16 @@ func InvokeModeration( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -179,12 +212,16 @@ func ValidateProviderCredentials( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService( @@ -202,12 +239,16 @@ func ValidateModelCredentials( max_timeout_seconds int, ) { // create session - session := createSession( + session, err := createSession( r, access_types.PLUGIN_ACCESS_TYPE_MODEL, access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS, ctx.GetString("cluster_id"), ) + if err != nil { + ctx.JSON(500, gin.H{"error": err.Error()}) + return + } defer session.Close() baseSSEService(