Skip to content

Commit

Permalink
refactor: local plugin cwd
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Sep 4, 2024
1 parent 1697a0e commit 277ce20
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 81 deletions.
8 changes: 8 additions & 0 deletions internal/core/plugin_manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/types/app"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
"github.com/langgenius/dify-plugin-daemon/internal/utils/lock"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
"github.com/langgenius/dify-plugin-daemon/internal/utils/mapping"
)

type PluginManager struct {
Expand All @@ -19,6 +21,11 @@ type PluginManager struct {

maxPluginPackageSize int64
workingDirectory string

// running plugin in storage contains relations between plugin packages and their running instances
runningPluginInStorage mapping.Map[string, string]
// start process lock
startProcessLock *lock.HighGranularityLock
}

var (
Expand All @@ -30,6 +37,7 @@ func InitGlobalPluginManager(cluster *cluster.Cluster, configuration *app.Config
cluster: cluster,
maxPluginPackageSize: configuration.MaxPluginPackageSize,
workingDirectory: configuration.PluginWorkingPath,
startProcessLock: lock.NewHighGranularityLock(),
}
manager.Init(configuration)
}
Expand Down
1 change: 0 additions & 1 deletion internal/core/plugin_manager/remote_manager/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ func (r *RemotePluginServer) Stop() error {
// Launch starts the server
func (r *RemotePluginServer) Launch() error {
// kill the process if port is already in use
// TODO: switch to optional
exec.Command("fuser", "-k", "tcp", fmt.Sprintf("%d", r.server.port)).Run()

time.Sleep(time.Millisecond * 100)
Expand Down
60 changes: 40 additions & 20 deletions internal/core/plugin_manager/watcher.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plugin_manager

import (
"errors"
"fmt"
"io"
"os"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/local_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/positive_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/remote_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/checksum"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/verifier"
"github.com/langgenius/dify-plugin-daemon/internal/types/app"
Expand Down Expand Up @@ -75,7 +77,23 @@ func (p *PluginManager) handleNewPlugins(config *app.Config) {
continue
}

identity, err := plugin_interface.Identity()
if err != nil {
log.Error("get plugin identity error: %v", err)
continue
}

// store the plugin in the storage, avoid duplicate loading
p.runningPluginInStorage.Store(plugin.Runtime.State.AbsolutePath, identity.String())

routine.Submit(func() {
defer func() {
if r := recover(); r != nil {
log.Error("plugin runtime error: %v", r)
}
}()
// delete the plugin from the storage when the plugin is stopped
defer p.runningPluginInStorage.Delete(plugin.Runtime.State.AbsolutePath)
p.lifetime(plugin_interface)
})
}
Expand All @@ -100,16 +118,20 @@ func (p *PluginManager) loadNewPlugins(root_path string) <-chan *pluginRuntimeWi
routine.Submit(func() {
for _, plugin := range plugins {
if !plugin.IsDir() {
plugin, err := p.loadPlugin(path.Join(root_path, plugin.Name()))
abs_path := path.Join(root_path, plugin.Name())
if _, ok := p.runningPluginInStorage.Load(abs_path); ok {
// if the plugin is already running, skip it
continue
}

plugin, err := p.loadPlugin(abs_path)
if err != nil {
log.Error("load plugin error: %v", err)
continue
}

ch <- plugin
}
}

close(ch)
})

Expand All @@ -119,50 +141,49 @@ func (p *PluginManager) loadNewPlugins(root_path string) <-chan *pluginRuntimeWi
func (p *PluginManager) loadPlugin(plugin_path string) (*pluginRuntimeWithDecoder, error) {
pack, err := os.Open(plugin_path)
if err != nil {
log.Error("open plugin package error: %v", err)
return nil, err
return nil, errors.Join(err, fmt.Errorf("open plugin package error"))
}
defer pack.Close()

if info, err := pack.Stat(); err != nil {
log.Error("get plugin package info error: %v", err)
return nil, err
return nil, errors.Join(err, fmt.Errorf("get plugin package info error"))
} else if info.Size() > p.maxPluginPackageSize {
log.Error("plugin package size is too large: %d", info.Size())
return nil, err
}

plugin_zip, err := io.ReadAll(pack)
if err != nil {
log.Error("read plugin package error: %v", err)
return nil, err
return nil, errors.Join(err, fmt.Errorf("read plugin package error"))
}

decoder, err := decoder.NewZipPluginDecoder(plugin_zip)
if err != nil {
log.Error("create plugin decoder error: %v", err)
return nil, err
return nil, errors.Join(err, fmt.Errorf("create plugin decoder error"))
}

// get manifest
manifest, err := decoder.Manifest()
if err != nil {
log.Error("get plugin manifest error: %v", err)
return nil, err
return nil, errors.Join(err, fmt.Errorf("get plugin manifest error"))
}

// check if already exists
if _, exist := p.m.Load(manifest.Identity()); exist {
log.Warn("plugin already exists: %s", manifest.Identity())
return nil, fmt.Errorf("plugin already exists: %s", manifest.Identity())
return nil, errors.Join(fmt.Errorf("plugin already exists: %s", manifest.Identity()), err)
}

plugin_working_path := path.Join(p.workingDirectory, manifest.Identity())
// TODO: use plugin unique id as the working directory
checksum, err := checksum.CalculateChecksum(decoder)
if err != nil {
return nil, errors.Join(err, fmt.Errorf("calculate checksum error"))
}

plugin_working_path := path.Join(p.workingDirectory, fmt.Sprintf("%s@%s", manifest.Identity(), checksum))

// check if working directory exists
if _, err := os.Stat(plugin_working_path); err == nil {
log.Warn("plugin working directory already exists: %s", plugin_working_path)
return nil, fmt.Errorf("plugin working directory already exists: %s", plugin_working_path)
return nil, errors.Join(fmt.Errorf("plugin working directory already exists: %s", plugin_working_path), err)
}

// copy to working directory
Expand All @@ -187,8 +208,7 @@ func (p *PluginManager) loadPlugin(plugin_path string) (*pluginRuntimeWithDecode

return nil
}); err != nil {
log.Error("copy plugin to working directory error: %v", err)
return nil, err
return nil, errors.Join(fmt.Errorf("copy plugin to working directory error: %v", err), err)
}

return &pluginRuntimeWithDecoder{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func isPluginName(fl validator.FieldLevel) bool {
}

func (p *PluginDeclaration) Identity() string {
return parser.MarshalPluginUniqueIdentifier(p.Name, p.Version)
return parser.MarshalPluginID(p.Name, p.Version)
}

func (p *PluginDeclaration) ManifestValidate() error {
Expand Down
58 changes: 0 additions & 58 deletions internal/utils/http_requests/http_warpper.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,64 +131,6 @@ func RequestAndParseStream[T any](client *http.Client, url string, method string
return ch, nil
}

// TODO: improve this, deduplicate code
func RequestAndParseStreamMap(client *http.Client, url string, method string, options ...HttpOptions) (*stream.StreamResponse[map[string]any], error) {
resp, err := Request(client, url, method, options...)
if err != nil {
return nil, err
}

if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
error_text, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("request failed with status code: %d and respond with: %s", resp.StatusCode, error_text)
}

ch := stream.NewStreamResponse[map[string]any](1024)

// get read timeout
read_timeout := int64(60000)
for _, option := range options {
if option.Type == "read_timeout" {
read_timeout = option.Value.(int64)
break
}
}
time.AfterFunc(time.Millisecond*time.Duration(read_timeout), func() {
// close the response body if timeout
resp.Body.Close()
})

routine.Submit(func() {
scanner := bufio.NewScanner(resp.Body)
defer resp.Body.Close()

for scanner.Scan() {
data := scanner.Bytes()
if len(data) == 0 {
continue
}

if bytes.HasPrefix(data, []byte("data: ")) {
// split
data = data[6:]
}

// unmarshal
t, err := parser.UnmarshalJsonBytes2Map(data)
if err != nil {
continue
}

ch.Write(t)
}

ch.Close()
})

return ch, nil
}

func GetAndParseStream[T any](client *http.Client, url string, options ...HttpOptions) (*stream.StreamResponse[T], error) {
return RequestAndParseStream[T](client, url, "GET", options...)
}
Expand Down
51 changes: 51 additions & 0 deletions internal/utils/lock/lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package lock

import (
"sync"
"sync/atomic"
)

type mutex struct {
*sync.Mutex
count int32
}

type HighGranularityLock struct {
m map[string]*mutex
l sync.Mutex
}

func NewHighGranularityLock() *HighGranularityLock {
return &HighGranularityLock{
m: make(map[string]*mutex),
}
}

func (l *HighGranularityLock) Lock(key string) {
l.l.Lock()
var m *mutex
var ok bool
if m, ok = l.m[key]; !ok {
m = &mutex{Mutex: &sync.Mutex{}, count: 1}
l.m[key] = m
} else {
atomic.AddInt32(&m.count, 1)
}
l.l.Unlock()

m.Lock()
}

func (l *HighGranularityLock) Unlock(key string) {
l.l.Lock()
m, ok := l.m[key]
if !ok {
return
}
atomic.AddInt32(&m.count, -1)
if atomic.LoadInt32(&m.count) == 0 {
delete(l.m, key)
}
l.l.Unlock()
m.Unlock()
}
40 changes: 40 additions & 0 deletions internal/utils/lock/lock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package lock

import (
"fmt"
"sync"
"testing"
)

func TestHighGranularityLock(t *testing.T) {
l := NewHighGranularityLock()

data := []int{}
add := func(key int) {
l.Lock(fmt.Sprintf("%d", key))
data[key]++
l.Unlock(fmt.Sprintf("%d", key))
}

for i := 0; i < 1000; i++ {
data = append(data, 0)
}

wg := sync.WaitGroup{}
for i := 0; i < 1000; i++ {
wg.Add(1)
go func() {
for j := 0; j < 1000; j++ {
add(j)
}
wg.Done()
}()
}
wg.Wait()

for _, v := range data {
if v != 1000 {
t.Fatal("data not equal")
}
}
}
2 changes: 1 addition & 1 deletion internal/utils/parser/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ package parser

import "fmt"

func MarshalPluginUniqueIdentifier(name string, version string) string {
func MarshalPluginID(name string, version string) string {
return fmt.Sprintf("%s:%s", name, version)
}

0 comments on commit 277ce20

Please sign in to comment.