Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Sep 4, 2024
1 parent 7c0f8a8 commit 3603a9c
Show file tree
Hide file tree
Showing 31 changed files with 147 additions and 136 deletions.
4 changes: 2 additions & 2 deletions internal/cluster/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (c *Cluster) autoGCPlugins() error {
)
}

func (c *Cluster) IsPluginNoCurrentNode(identity string) bool {
_, ok := c.plugins.Load(identity)
func (c *Cluster) IsPluginNoCurrentNode(identity plugin_entities.PluginUniqueIdentifier) bool {
_, ok := c.plugins.Load(identity.String())
return ok
}
4 changes: 2 additions & 2 deletions internal/cluster/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func (r *fakePlugin) Checksum() string {
return ""
}

func (r *fakePlugin) Identity() (plugin_entities.PluginIdentity, error) {
return plugin_entities.PluginIdentity(""), nil
func (r *fakePlugin) Identity() (plugin_entities.PluginUniqueIdentifier, error) {
return plugin_entities.PluginUniqueIdentifier(""), nil
}

func (r *fakePlugin) StartPlugin() error {
Expand Down
2 changes: 1 addition & 1 deletion internal/core/plugin_daemon/backwards_invocation/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func executeDifyInvocationStorageTask(
return
}

plugin_id := handle.session.PluginIdentity
plugin_id := handle.session.PluginUniqueIdentifier

if request.Opt == dify_invocation.STORAGE_OPT_GET {
data, err := persistence.Load(tenant_id, plugin_id.PluginID(), request.Key)
Expand Down
2 changes: 1 addition & 1 deletion internal/core/plugin_daemon/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func genericInvokePlugin[Req any, Rsp any](
request *Req,
response_buffer_size int,
) (*stream.StreamResponse[Rsp], error) {
runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginIdentity.String())
runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginUniqueIdentifier)
if runtime == nil {
return nil, errors.New("plugin not found")
}
Expand Down
4 changes: 2 additions & 2 deletions internal/core/plugin_manager/aws_manager/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func (r *AWSPluginRuntime) InitEnvironment() error {
return nil
}

func (r *AWSPluginRuntime) Identity() (plugin_entities.PluginIdentity, error) {
return plugin_entities.PluginIdentity(fmt.Sprintf("%s@%s", r.Config.Identity(), r.Checksum())), nil
func (r *AWSPluginRuntime) Identity() (plugin_entities.PluginUniqueIdentifier, error) {
return plugin_entities.PluginUniqueIdentifier(fmt.Sprintf("%s@%s", r.Config.Identity(), r.Checksum())), nil
}

func (r *AWSPluginRuntime) initEnvironment() error {
Expand Down
4 changes: 2 additions & 2 deletions internal/core/plugin_manager/aws_manager/packager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func (r *TPluginRuntime) Checksum() string {
return "test_checksum"
}

func (r *TPluginRuntime) Identity() (plugin_entities.PluginIdentity, error) {
return plugin_entities.PluginIdentity("test_identity"), nil
func (r *TPluginRuntime) Identity() (plugin_entities.PluginUniqueIdentifier, error) {
return plugin_entities.PluginUniqueIdentifier("test_identity"), nil
}

func (r *TPluginRuntime) StartPlugin() error {
Expand Down
9 changes: 5 additions & 4 deletions internal/core/plugin_manager/local_manager/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ func (r *LocalPluginRuntime) InitEnvironment() error {
return nil
}

// execute init command
handle := exec.Command("bash", r.Config.Execution.Install)
// execute init command, create
// TODO
handle := exec.Command("bash")
handle.Dir = r.State.AbsolutePath
handle.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

Expand Down Expand Up @@ -125,6 +126,6 @@ func (r *LocalPluginRuntime) InitEnvironment() error {
return nil
}

func (r *LocalPluginRuntime) Identity() (plugin_entities.PluginIdentity, error) {
return plugin_entities.PluginIdentity(fmt.Sprintf("%s@%s", r.Config.Identity(), r.Checksum())), nil
func (r *LocalPluginRuntime) Identity() (plugin_entities.PluginUniqueIdentifier, error) {
return plugin_entities.PluginUniqueIdentifier(fmt.Sprintf("%s@%s", r.Config.Identity(), r.Checksum())), nil
}
19 changes: 19 additions & 0 deletions internal/core/plugin_manager/local_manager/environment_python.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package local_manager

import "os/exec"

func (p *LocalPluginRuntime) InitPythonEnvironment(requirements_txt string) error {
// create virtual env
identity, err := p.Identity()
if err != nil {
return err
}

cmd := exec.Command("python", "-m", "venv", identity.String())

// set working directory
cmd.Dir = p.WorkingPath

// TODO
return nil
}
5 changes: 4 additions & 1 deletion internal/core/plugin_manager/local_manager/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ func (r *LocalPluginRuntime) StartPlugin() error {

r.init()
// start plugin
e := exec.Command("bash", r.Config.Execution.Launch)
// TODO: use exec.Command("bash") instead of exec.Command("bash", r.Config.Execution.Launch)
e := exec.Command("bash")
e.Dir = r.State.AbsolutePath
// add env INSTALL_METHOD=local
e.Env = append(e.Env, "INSTALL_METHOD=local", "PATH="+os.Getenv("PATH"))

// NOTE: subprocess will be taken care of by subprocess manager
// ensure all subprocess are killed when parent process exits, especially on Golang debugger
process.WrapProcess(e)

// get writer
Expand Down
22 changes: 11 additions & 11 deletions internal/core/plugin_manager/local_manager/stdio_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ var (
)

type stdioHolder struct {
id string
plugin_identity string
writer io.WriteCloser
reader io.ReadCloser
err_reader io.ReadCloser
l *sync.Mutex
listener map[string]func([]byte)
error_listener map[string]func([]byte)
started bool
id string
plugin_unique_identifier string
writer io.WriteCloser
reader io.ReadCloser
err_reader io.ReadCloser
l *sync.Mutex
listener map[string]func([]byte)
error_listener map[string]func([]byte)
started bool

err_message string
last_err_message_updated_at time.Time
Expand Down Expand Up @@ -94,7 +94,7 @@ func (s *stdioHolder) StartStdout() {
continue
}

log.Info("plugin %s: %s", s.plugin_identity, logEvent.Message)
log.Info("plugin %s: %s", s.plugin_unique_identifier, logEvent.Message)
}
case plugin_entities.PLUGIN_EVENT_SESSION:
for _, listener := range listeners {
Expand All @@ -107,7 +107,7 @@ func (s *stdioHolder) StartStdout() {
}
}
case plugin_entities.PLUGIN_EVENT_ERROR:
log.Error("plugin %s: %s", s.plugin_identity, event.Data)
log.Error("plugin %s: %s", s.plugin_unique_identifier, event.Data)
case plugin_entities.PLUGIN_EVENT_HEARTBEAT:
s.last_active_at = time.Now()
}
Expand Down
14 changes: 7 additions & 7 deletions internal/core/plugin_manager/local_manager/stdio_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ import (
)

func PutStdioIo(
plugin_identity string, writer io.WriteCloser,
plugin_unique_identifier string, writer io.WriteCloser,
reader io.ReadCloser, err_reader io.ReadCloser,
) *stdioHolder {
id := uuid.New().String()

holder := &stdioHolder{
plugin_identity: plugin_identity,
writer: writer,
reader: reader,
err_reader: err_reader,
id: id,
l: &sync.Mutex{},
plugin_unique_identifier: plugin_unique_identifier,
writer: writer,
reader: reader,
err_reader: err_reader,
id: id,
l: &sync.Mutex{},

health_chan_lock: &sync.Mutex{},
health_chan: make(chan bool),
Expand Down
4 changes: 2 additions & 2 deletions internal/core/plugin_manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func (p *PluginManager) List() []plugin_entities.PluginRuntimeInterface {
return runtimes
}

func (p *PluginManager) Get(identity string) plugin_entities.PluginRuntimeInterface {
if v, ok := p.m.Load(identity); ok {
func (p *PluginManager) Get(identity plugin_entities.PluginUniqueIdentifier) plugin_entities.PluginRuntimeInterface {
if v, ok := p.m.Load(identity.String()); ok {
if r, ok := v.(plugin_entities.PluginRuntimeInterface); ok {
return r
}
Expand Down
4 changes: 2 additions & 2 deletions internal/core/plugin_manager/remote_manager/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
)

func (r *RemotePluginRuntime) Identity() (plugin_entities.PluginIdentity, error) {
func (r *RemotePluginRuntime) Identity() (plugin_entities.PluginUniqueIdentifier, error) {
identity := strings.Join([]string{r.Configuration().Identity(), r.tenant_id}, ":")
return plugin_entities.PluginIdentity(fmt.Sprintf("%s@%s", identity, r.Checksum())), nil
return plugin_entities.PluginUniqueIdentifier(fmt.Sprintf("%s@%s", identity, r.Checksum())), nil
}

func (r *RemotePluginRuntime) Cleanup() {
Expand Down
4 changes: 0 additions & 4 deletions internal/core/plugin_manager/remote_manager/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,6 @@ func TestAcceptConnection(t *testing.T) {
Plugins: []string{
"test",
},
Execution: plugin_entities.PluginExecution{
Install: "echo 'hello'",
Launch: "echo 'hello'",
},
Meta: plugin_entities.PluginMeta{
Version: "0.0.1",
Arch: []constants.Arch{
Expand Down
32 changes: 16 additions & 16 deletions internal/core/session_manager/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ type Session struct {
runtime plugin_entities.PluginRuntimeInterface `json:"-"`
persistence *persistence.Persistence `json:"-"`

TenantID string `json:"tenant_id"`
UserID string `json:"user_id"`
PluginIdentity plugin_entities.PluginIdentity `json:"plugin_identity"`
ClusterID string `json:"cluster_id"`
InvokeFrom access_types.PluginAccessType `json:"invoke_from"`
Action access_types.PluginAccessAction `json:"action"`
Declaration *plugin_entities.PluginDeclaration `json:"declaration"`
TenantID string `json:"tenant_id"`
UserID string `json:"user_id"`
PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier"`
ClusterID string `json:"cluster_id"`
InvokeFrom access_types.PluginAccessType `json:"invoke_from"`
Action access_types.PluginAccessAction `json:"action"`
Declaration *plugin_entities.PluginDeclaration `json:"declaration"`
}

func sessionKey(id string) string {
Expand All @@ -42,21 +42,21 @@ func sessionKey(id string) string {
func NewSession(
tenant_id string,
user_id string,
plugin_identity plugin_entities.PluginIdentity,
plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
cluster_id string,
invoke_from access_types.PluginAccessType,
action access_types.PluginAccessAction,
declaration *plugin_entities.PluginDeclaration,
) *Session {
s := &Session{
ID: uuid.New().String(),
TenantID: tenant_id,
UserID: user_id,
PluginIdentity: plugin_identity,
ClusterID: cluster_id,
InvokeFrom: invoke_from,
Action: action,
Declaration: declaration,
ID: uuid.New().String(),
TenantID: tenant_id,
UserID: user_id,
PluginUniqueIdentifier: plugin_unique_identifier,
ClusterID: cluster_id,
InvokeFrom: invoke_from,
Action: action,
Declaration: declaration,
}

session_lock.Lock()
Expand Down
20 changes: 10 additions & 10 deletions internal/server/controllers/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ import (

func SetupEndpoint(ctx *gin.Context) {
BindRequest(ctx, func(request struct {
PluginIdentity string `json:"plugin_identity" binding:"required"`
TenantID string `json:"tenant_id" binding:"required"`
UserID string `json:"user_id" binding:"required"`
Settings map[string]any `json:"settings" binding:"omitempty"`
PluginUniqueIdentifier string `json:"plugin_unique_identifier" binding:"required"`
TenantID string `json:"tenant_id" binding:"required"`
UserID string `json:"user_id" binding:"required"`
Settings map[string]any `json:"settings" binding:"omitempty"`
}) {
plugin_identity := request.PluginIdentity
plugin_unique_identifier := request.PluginUniqueIdentifier
tenant_id := request.TenantID
user_id := request.UserID
settings := request.Settings

ctx.JSON(200, service.SetupEndpoint(
tenant_id, user_id, plugin_entities.PluginIdentity(plugin_identity), settings,
tenant_id, user_id, plugin_entities.PluginUniqueIdentifier(plugin_unique_identifier), settings,
))
})
}
Expand All @@ -40,12 +40,12 @@ func ListEndpoints(ctx *gin.Context) {

func RemoveEndpoint(ctx *gin.Context) {
BindRequest(ctx, func(request struct {
PluginIdentity string `json:"plugin_identity"`
TenantID string `json:"tenant_id"`
PluginUniqueIdentifier string `json:"plugin_unique_identifier"`
TenantID string `json:"tenant_id"`
}) {
plugin_identity := request.PluginIdentity
plugin_unique_identifier := request.PluginUniqueIdentifier
tenant_id := request.TenantID

ctx.JSON(200, service.RemoveEndpoint(plugin_identity, tenant_id))
ctx.JSON(200, service.RemoveEndpoint(plugin_unique_identifier, tenant_id))
})
}
9 changes: 7 additions & 2 deletions internal/server/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/db"
"github.com/langgenius/dify-plugin-daemon/internal/service"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
)
Expand Down Expand Up @@ -54,8 +55,12 @@ func (app *App) EndpointHandler(ctx *gin.Context, hook_id string, path string) {
}

// check if plugin exists in current node
if !app.cluster.IsPluginNoCurrentNode(plugin_installation.PluginIdentity) {
app.redirectPluginInvokeByPluginID(ctx, plugin_installation.PluginIdentity)
if !app.cluster.IsPluginNoCurrentNode(
plugin_entities.PluginUniqueIdentifier(plugin_installation.PluginUniqueIdentifier),
) {
app.redirectPluginInvokeByPluginID(ctx, plugin_entities.PluginUniqueIdentifier(
plugin_installation.PluginUniqueIdentifier,
))
} else {
service.Endpoint(ctx, &endpoint, &plugin_installation, path)
}
Expand Down
12 changes: 5 additions & 7 deletions internal/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,28 @@ func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
reader: bytes.NewReader(raw),
}

identity, err := parser.UnmarshalJsonBytes[plugin_entities.InvokePluginPluginIdentity](raw)
identity, err := parser.UnmarshalJsonBytes[plugin_entities.BasePluginIdentifier](raw)

if err != nil {
ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
return
}

plugin_id := parser.MarshalPluginIdentity(identity.PluginName, identity.PluginVersion)

// check if plugin in current node
if !app.cluster.IsPluginNoCurrentNode(
plugin_id,
identity.PluginUniqueIdentifier,
) {
app.redirectPluginInvokeByPluginID(ctx, plugin_id)
app.redirectPluginInvokeByPluginID(ctx, identity.PluginUniqueIdentifier)
ctx.Abort()
} else {
ctx.Next()
}
}
}

func (app *App) redirectPluginInvokeByPluginID(ctx *gin.Context, plugin_id string) {
func (app *App) redirectPluginInvokeByPluginID(ctx *gin.Context, plugin_id plugin_entities.PluginUniqueIdentifier) {
// try find the correct node
nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_id)
nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_id.PluginID())
if err != nil {
ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
log.Error("fetch plugin available nodes failed: %s", err.Error())
Expand Down
6 changes: 4 additions & 2 deletions internal/service/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func Endpoint(

// fetch plugin
manager := plugin_manager.GetGlobalPluginManager()
runtime := manager.Get(plugin_installation.PluginIdentity)
runtime := manager.Get(
plugin_entities.PluginUniqueIdentifier(plugin_installation.PluginUniqueIdentifier),
)
if runtime == nil {
ctx.JSON(404, gin.H{"error": "plugin not found"})
return
Expand Down Expand Up @@ -74,7 +76,7 @@ func Endpoint(
session := session_manager.NewSession(
endpoint.TenantID,
"",
plugin_entities.PluginIdentity(plugin_installation.PluginIdentity),
plugin_entities.PluginUniqueIdentifier(plugin_installation.PluginUniqueIdentifier),
ctx.GetString("cluster_id"),
access_types.PLUGIN_ACCESS_TYPE_ENDPOINT,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_ENDPOINT,
Expand Down
Loading

0 comments on commit 3603a9c

Please sign in to comment.