From 5b96e61a9d6a102779e01a1a60111f475ad1e817 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 19 Jul 2024 22:20:12 +0800 Subject: [PATCH] feat: generic invocation --- internal/core/plugin_daemon/basic.go | 9 +- internal/core/plugin_daemon/model_service.go | 151 ++++++++++++++---- internal/core/plugin_daemon/tool_service.go | 77 +-------- .../core/plugin_manager/stdio_holder/io.go | 10 +- internal/service/invoke.go | 3 +- .../types/entities/plugin_entities/event.go | 4 - internal/types/entities/requests/model.go | 33 ++-- internal/utils/parser/struct2map.go | 48 ++++++ 8 files changed, 217 insertions(+), 118 deletions(-) create mode 100644 internal/utils/parser/struct2map.go diff --git a/internal/core/plugin_daemon/basic.go b/internal/core/plugin_daemon/basic.go index 565a3f2..18a495f 100644 --- a/internal/core/plugin_daemon/basic.go +++ b/internal/core/plugin_daemon/basic.go @@ -10,8 +10,13 @@ const ( type PluginAccessAction string const ( - PLUGIN_ACCESS_ACTION_INVOKE_TOOL PluginAccessAction = "invoke_tool" - PLUGIN_ACCESS_ACTION_INVOKE_LLM PluginAccessAction = "invoke_llm" + PLUGIN_ACCESS_ACTION_INVOKE_TOOL PluginAccessAction = "invoke_tool" + PLUGIN_ACCESS_ACTION_INVOKE_LLM PluginAccessAction = "invoke_llm" + PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING PluginAccessAction = "invoke_text_embedding" + PLUGIN_ACCESS_ACTION_INVOKE_RERANK PluginAccessAction = "invoke_rerank" + PLUGIN_ACCESS_ACTION_INVOKE_TTS PluginAccessAction = "invoke_tts" + PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT PluginAccessAction = "invoke_speech2text" + PLUGIN_ACCESS_ACTION_INVOKE_MODERATION PluginAccessAction = "invoke_moderation" ) const ( diff --git a/internal/core/plugin_daemon/model_service.go b/internal/core/plugin_daemon/model_service.go index 965296b..b575ac8 100644 --- a/internal/core/plugin_daemon/model_service.go +++ b/internal/core/plugin_daemon/model_service.go @@ -5,6 +5,7 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" "github.com/langgenius/dify-plugin-daemon/internal/utils/log" @@ -12,39 +13,21 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) -func getInvokeModelMap( +func genericInvokePlugin[Req any, Rsp any]( session *session_manager.Session, + request *Req, + response_buffer_size int, + typ PluginAccessType, action PluginAccessAction, - request *requests.RequestInvokeLLM, -) map[string]any { - req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_MODEL, action) - data := req["data"].(map[string]any) - - data["provider"] = request.Provider - data["model"] = request.Model - data["model_type"] = request.ModelType - data["model_parameters"] = request.ModelParameters - data["prompt_messages"] = request.PromptMessages - data["tools"] = request.Tools - data["stop"] = request.Stop - data["stream"] = request.Stream - data["credentials"] = request.Credentials - - return req -} - -func InvokeLLM( - session *session_manager.Session, - request *requests.RequestInvokeLLM, ) ( - *stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error, + *stream.StreamResponse[Rsp], error, ) { runtime := plugin_manager.Get(session.PluginIdentity()) if runtime == nil { return nil, errors.New("plugin not found") } - response := stream.NewStreamResponse[plugin_entities.InvokeModelResponseChunk](512) + response := stream.NewStreamResponse[Rsp](response_buffer_size) listener := runtime.Listen(session.ID()) listener.AddListener(func(message []byte) { @@ -56,7 +39,7 @@ func InvokeLLM( switch chunk.Type { case plugin_entities.SESSION_MESSAGE_TYPE_STREAM: - chunk, err := parser.UnmarshalJsonBytes[plugin_entities.InvokeModelResponseChunk](chunk.Data) + chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data) if err != nil { log.Error("unmarshal json failed: %s", err.Error()) return @@ -66,8 +49,15 @@ func InvokeLLM( invokeDify(runtime, session, chunk.Data) case plugin_entities.SESSION_MESSAGE_TYPE_END: response.Close() + case plugin_entities.SESSION_MESSAGE_TYPE_ERROR: + e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data) + if err != nil { + break + } + response.WriteError(errors.New(e.Error)) + response.Close() default: - log.Error("unknown stream message type: %s", chunk.Type) + response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type))) response.Close() } }) @@ -79,10 +69,117 @@ func InvokeLLM( runtime.Write(session.ID(), []byte(parser.MarshalJson( getInvokeModelMap( session, - PLUGIN_ACCESS_ACTION_INVOKE_LLM, + typ, + action, request, ), ))) return response, nil } + +func getInvokeModelMap( + session *session_manager.Session, + typ PluginAccessType, + action PluginAccessAction, + request any, +) map[string]any { + req := getBasicPluginAccessMap(session.ID(), session.UserID(), typ, action) + data := req["data"].(map[string]any) + + for k, v := range parser.StructToMap(request) { + data[k] = v + } + + return req +} + +func InvokeLLM( + session *session_manager.Session, + request *requests.RequestInvokeLLM, +) ( + *stream.StreamResponse[model_entities.LLMResultChunk], error, +) { + return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk]( + session, + request, + 512, + PLUGIN_ACCESS_TYPE_MODEL, + PLUGIN_ACCESS_ACTION_INVOKE_LLM, + ) +} + +func InvokeTextEmbedding( + session *session_manager.Session, + request *requests.RequestInvokeTextEmbedding, +) ( + *stream.StreamResponse[model_entities.TextEmbeddingResult], error, +) { + return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult]( + session, + request, + 1, + PLUGIN_ACCESS_TYPE_MODEL, + PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING, + ) +} + +func InvokeRerank( + session *session_manager.Session, + request *requests.RequestInvokeRerank, +) ( + *stream.StreamResponse[model_entities.RerankResult], error, +) { + return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult]( + session, + request, + 1, + PLUGIN_ACCESS_TYPE_MODEL, + PLUGIN_ACCESS_ACTION_INVOKE_RERANK, + ) +} + +func InvokeTTS( + session *session_manager.Session, + request *requests.RequestInvokeTTS, +) ( + *stream.StreamResponse[string], error, +) { + return genericInvokePlugin[requests.RequestInvokeTTS, string]( + session, + request, + 1, + PLUGIN_ACCESS_TYPE_MODEL, + PLUGIN_ACCESS_ACTION_INVOKE_TTS, + ) +} + +func InvokeSpeech2Text( + session *session_manager.Session, + request *requests.RequestInvokeSpeech2Text, +) ( + *stream.StreamResponse[string], error, +) { + return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string]( + session, + request, + 1, + PLUGIN_ACCESS_TYPE_MODEL, + PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT, + ) +} + +func InvokeModeration( + session *session_manager.Session, + request *requests.RequestInvokeModeration, +) ( + *stream.StreamResponse[bool], error, +) { + return genericInvokePlugin[requests.RequestInvokeModeration, bool]( + session, + request, + 1, + PLUGIN_ACCESS_TYPE_MODEL, + PLUGIN_ACCESS_ACTION_INVOKE_MODERATION, + ) +} diff --git a/internal/core/plugin_daemon/tool_service.go b/internal/core/plugin_daemon/tool_service.go index 5bdee61..cfc92db 100644 --- a/internal/core/plugin_daemon/tool_service.go +++ b/internal/core/plugin_daemon/tool_service.go @@ -1,86 +1,23 @@ package plugin_daemon import ( - "errors" - - "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" - "github.com/langgenius/dify-plugin-daemon/internal/utils/log" - "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) -func getInvokeToolMap( - session *session_manager.Session, - action PluginAccessAction, - request *requests.RequestInvokeTool, -) map[string]any { - req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_TOOL, action) - data := req["data"].(map[string]any) - - data["provider"] = request.Provider - data["tool"] = request.Tool - data["parameters"] = request.ToolParameters - data["credentials"] = request.Credentials - - return req -} - func InvokeTool( session *session_manager.Session, request *requests.RequestInvokeTool, ) ( *stream.StreamResponse[plugin_entities.ToolResponseChunk], error, ) { - runtime := plugin_manager.Get(session.PluginIdentity()) - if runtime == nil { - return nil, errors.New("plugin not found") - } - - response := stream.NewStreamResponse[plugin_entities.ToolResponseChunk](512) - - listener := runtime.Listen(session.ID()) - listener.AddListener(func(message []byte) { - chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message) - if err != nil { - log.Error("unmarshal json failed: %s", err.Error()) - return - } - - switch chunk.Type { - case plugin_entities.SESSION_MESSAGE_TYPE_STREAM: - chunk, err := parser.UnmarshalJsonBytes[plugin_entities.ToolResponseChunk](chunk.Data) - if err != nil { - log.Error("unmarshal json failed: %s", err.Error()) - return - } - response.Write(chunk) - case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE: - invokeDify(runtime, session, chunk.Data) - case plugin_entities.SESSION_MESSAGE_TYPE_END: - response.Close() - case plugin_entities.SESSION_MESSAGE_TYPE_ERROR: - e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data) - if err != nil { - break - } - response.WriteError(errors.New(e.Error)) - response.Close() - default: - response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type))) - response.Close() - } - }) - - response.OnClose(func() { - listener.Close() - }) - - runtime.Write(session.ID(), []byte(parser.MarshalJson( - getInvokeToolMap(session, PLUGIN_ACCESS_ACTION_INVOKE_TOOL, request)), - )) - - return response, nil + return genericInvokePlugin[requests.RequestInvokeTool, plugin_entities.ToolResponseChunk]( + session, + request, + 128, + PLUGIN_ACCESS_TYPE_TOOL, + PLUGIN_ACCESS_ACTION_INVOKE_TOOL, + ) } diff --git a/internal/core/plugin_manager/stdio_holder/io.go b/internal/core/plugin_manager/stdio_holder/io.go index d89cb11..f715e19 100644 --- a/internal/core/plugin_manager/stdio_holder/io.go +++ b/internal/core/plugin_manager/stdio_holder/io.go @@ -71,6 +71,10 @@ func (s *stdioHolder) StartStdout() { scanner := bufio.NewScanner(s.reader) for scanner.Scan() { data := scanner.Bytes() + if len(data) == 0 { + continue + } + event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data) if err != nil { // log.Error("unmarshal json failed: %s", err.Error()) @@ -101,11 +105,7 @@ func (s *stdioHolder) StartStdout() { } } case plugin_entities.PLUGIN_EVENT_ERROR: - for listener_session_id, listener := range s.error_listener { - if listener_session_id == session_id { - listener(event.Data) - } - } + log.Error("plugin %s: %s", s.plugin_identity, event.Data) case plugin_entities.PLUGIN_EVENT_HEARTBEAT: s.last_active_at = time.Now() } diff --git a/internal/service/invoke.go b/internal/service/invoke.go index f89b92a..0a31dc2 100644 --- a/internal/service/invoke.go +++ b/internal/service/invoke.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" @@ -25,7 +26,7 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM] session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) defer session.Close() - baseSSEService(r, func() (*stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error) { + baseSSEService(r, func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) { return plugin_daemon.InvokeLLM(session, &r.Data) }, ctx) } diff --git a/internal/types/entities/plugin_entities/event.go b/internal/types/entities/plugin_entities/event.go index ffd4c8b..e88433e 100644 --- a/internal/types/entities/plugin_entities/event.go +++ b/internal/types/entities/plugin_entities/event.go @@ -2,8 +2,6 @@ package plugin_entities import ( "encoding/json" - - "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" ) type PluginUniversalEvent struct { @@ -51,8 +49,6 @@ type PluginResponseChunk struct { Data json.RawMessage `json:"data"` } -type InvokeModelResponseChunk = model_entities.LLMResultChunk - type ErrorResponse struct { Error string `json:"error"` } diff --git a/internal/types/entities/requests/model.go b/internal/types/entities/requests/model.go index a746dde..fae95e6 100644 --- a/internal/types/entities/requests/model.go +++ b/internal/types/entities/requests/model.go @@ -5,20 +5,29 @@ import ( ) type BaseRequestInvokeModel struct { - Provider string `json:"provider"` - ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"` - Model string `json:"model"` + Provider string `json:"provider" validate:"required"` + ModelType model_entities.ModelType `json:"model_type" mapstructure:"model_type" validate:"required,model_type"` + Model string `json:"model" validate:"required"` Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` } +func (r *BaseRequestInvokeModel) ToCallerArguments() map[string]any { + return map[string]any{ + "provider": r.Provider, + "model": r.Model, + "model_type": r.ModelType, + "credentials": r.Credentials, + } +} + type RequestInvokeLLM struct { BaseRequestInvokeModel - ModelParameters map[string]any `json:"model_parameters" validate:"omitempty,dive,is_basic_type"` - PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty,dive"` + ModelParameters map[string]any `json:"model_parameters" mapstructure:"model_parameters" validate:"omitempty,dive,is_basic_type"` + PromptMessages []model_entities.PromptMessage `json:"prompt_messages" mapstructure:"prompt_messages" validate:"omitempty,dive"` Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"` Stop []string `json:"stop" validate:"omitempty"` - Stream bool `json:"stream"` + Stream bool `json:"stream" mapstructure:"stream"` } type RequestInvokeTextEmbedding struct { @@ -32,14 +41,14 @@ type RequestInvokeRerank struct { Query string `json:"query" validate:"required"` Docs []string `json:"docs" validate:"required,dive"` - ScoreThreshold float64 `json:"score_threshold"` - TopN int `json:"top_n"` + ScoreThreshold float64 `json:"score_threshold" mapstructure:"score_threshold"` + TopN int `json:"top_n" mapstructure:"top_n"` } type RequestInvokeTTS struct { BaseRequestInvokeModel - ContentText string `json:"content_text" validate:"required"` + ContentText string `json:"content_text" mapstructure:"content_text" validate:"required"` Voice string `json:"voice" validate:"required"` } @@ -48,3 +57,9 @@ type RequestInvokeSpeech2Text struct { File string `json:"file" validate:"required"` // base64 encoded voice file } + +type RequestInvokeModeration struct { + BaseRequestInvokeModel + + Text string `json:"text" validate:"required"` +} diff --git a/internal/utils/parser/struct2map.go b/internal/utils/parser/struct2map.go new file mode 100644 index 0000000..7405c90 --- /dev/null +++ b/internal/utils/parser/struct2map.go @@ -0,0 +1,48 @@ +package parser + +import ( + "reflect" + "unicode" +) + +func StructToMap(data interface{}) map[string]interface{} { + result := make(map[string]interface{}) + val := reflect.ValueOf(data) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + for i := 0; i < val.NumField(); i++ { + field := val.Field(i) + typeField := val.Type().Field(i) + fieldName := toSnakeCase(typeField.Name) + + if typeField.Anonymous { + embeddedFields := StructToMap(field.Interface()) + for k, v := range embeddedFields { + result[k] = v + } + } else { + result[fieldName] = field.Interface() + } + } + return result +} + +func toSnakeCase(str string) string { + runes := []rune(str) + length := len(runes) + var out []rune + + for i := 0; i < length; i++ { + if unicode.IsUpper(runes[i]) { + if i > 0 { + out = append(out, '_') + } + out = append(out, unicode.ToLower(runes[i])) + } else { + out = append(out, runes[i]) + } + } + + return string(out) +}