Skip to content

Commit

Permalink
refactor: support aws lambda using persistence storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Aug 28, 2024
1 parent 1c9152f commit 1115a3f
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 37 deletions.
3 changes: 2 additions & 1 deletion internal/core/plugin_daemon/backwards_invocation/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
Expand Down Expand Up @@ -370,7 +371,7 @@ func executeDifyInvocationStorageTask(
return
}

persistence := handle.session.Storage()
persistence := persistence.GetPersistence()
if persistence == nil {
handle.WriteError(fmt.Errorf("persistence not found"))
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ func (h *AWSTransactionHandler) Handle(
}

session := session_manager.GetSession(session_id)
if err != nil {
log.Error("get session failed: %s", err.Error())
if session == nil {
log.Error("session not found: %s", session_id)
writer.WriteHeader(http.StatusInternalServerError)
writer.Write([]byte(err.Error()))
writer.Write([]byte("session not found"))
return
}

Expand Down
21 changes: 13 additions & 8 deletions internal/core/session_manager/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ func NewSession(
invoke_from access_types.PluginAccessType,
action access_types.PluginAccessAction,
declaration *plugin_entities.PluginDeclaration,
persistence *persistence.Persistence,
) *Session {
s := &Session{
ID: uuid.New().String(),
Expand All @@ -58,7 +57,6 @@ func NewSession(
InvokeFrom: invoke_from,
Action: action,
Declaration: declaration,
persistence: persistence,
}

session_lock.Lock()
Expand All @@ -74,9 +72,20 @@ func NewSession(

func GetSession(id string) *Session {
session_lock.RLock()
defer session_lock.RUnlock()
session := sessions[id]
session_lock.RUnlock()

if session == nil {
// if session not found, it may be generated by another node, try to get it from cache
session, err := cache.Get[Session](sessionKey(id))
if err != nil {
log.Error("get session info from cache failed, %s", err)
return nil
}
return session
}

return sessions[id]
return session
}

func DeleteSession(id string) {
Expand All @@ -101,10 +110,6 @@ func (s *Session) Runtime() plugin_entities.PluginRuntimeInterface {
return s.runtime
}

func (s *Session) Storage() *persistence.Persistence {
return s.persistence
}

type PLUGIN_IN_STREAM_EVENT string

const (
Expand Down
8 changes: 0 additions & 8 deletions internal/service/aws_transaction.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
package service

import (
"net/http"

"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
)

func HandleAWSPluginTransaction(handler *transaction.AWSTransactionHandler) gin.HandlerFunc {
return func(c *gin.Context) {
// get session id from the context
session_id := c.Request.Header.Get("Dify-Plugin-Session-ID")
session := session_manager.GetSession(session_id)
if session == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "session not found"})
return
}

handler.Handle(c, session_id)
}
Expand Down
8 changes: 0 additions & 8 deletions internal/service/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
Expand Down Expand Up @@ -41,12 +40,6 @@ func Endpoint(
return
}

persistence := persistence.GetPersistence()
if persistence == nil {
ctx.JSON(500, gin.H{"error": "persistence not found"})
return
}

session := session_manager.NewSession(
endpoint.TenantID,
"",
Expand All @@ -55,7 +48,6 @@ func Endpoint(
access_types.PLUGIN_ACCESS_TYPE_Endpoint,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_ENDPOINT,
runtime.Configuration(),
persistence,
)
defer session.Close()

Expand Down
9 changes: 0 additions & 9 deletions internal/service/invoke_tool.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package service

import (
"errors"

"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
Expand All @@ -22,11 +19,6 @@ func createSession[T any](
access_action access_types.PluginAccessAction,
cluster_id string,
) (*session_manager.Session, error) {
persistence := persistence.GetPersistence()
if persistence == nil {
return nil, errors.New("persistence not found")
}

plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)
runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity)

Expand All @@ -38,7 +30,6 @@ func createSession[T any](
access_type,
access_action,
runtime.Configuration(),
persistence,
)

session.BindRuntime(runtime)
Expand Down

0 comments on commit 1115a3f

Please sign in to comment.