From c9dc8aadef7e0c7cff026676da5ba58cd2a6ca81 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 19 Jul 2024 17:45:00 +0800 Subject: [PATCH] feat: invoke tool and models --- internal/core/dify_invocation/http_request.go | 28 +- internal/core/plugin_daemon/basic.go | 31 ++ internal/core/plugin_daemon/daemon.go | 65 ---- internal/core/plugin_daemon/invoke_dify.go | 3 +- internal/core/plugin_daemon/model_service.go | 88 ++++++ internal/core/plugin_daemon/tool_service.go | 86 ++++++ .../core/plugin_manager/aws_manager/run.go | 4 + internal/core/plugin_manager/lifetime.go | 9 + .../local_manager/environment.go | 2 +- .../core/plugin_manager/local_manager/run.go | 55 +++- .../core/plugin_manager/local_manager/type.go | 1 + .../core/plugin_manager/stdio_holder/io.go | 269 ++++++++--------- .../core/plugin_manager/stdio_holder/store.go | 102 +++++++ internal/core/plugin_manager/watcher.go | 8 +- internal/server/controller.go | 14 +- internal/server/http.go | 1 + internal/service/invoke.go | 59 +--- internal/service/runner.go | 80 +++++ internal/types/entities/model_entities/llm.go | 243 +++++++++++++++ .../types/entities/model_entities/llm_test.go | 278 ++++++++++++++++++ .../plugin_entities/basic_type_test.go | 10 +- .../types/entities/plugin_entities/event.go | 26 +- .../plugin_entities/model_configuration.go | 29 +- .../plugin_entities/plugin_declaration.go | 20 +- .../types/entities/plugin_entities/request.go | 17 +- .../plugin_entities/tool_configuration.go | 46 ++- internal/types/entities/requests/model.go | 38 +++ internal/types/entities/requests/tool.go | 32 ++ internal/types/entities/runtime.go | 1 + internal/types/validators/validators.go | 7 + internal/utils/cache/redis.go | 2 +- .../http_options.go | 2 +- .../http_request.go | 2 +- .../http_warpper.go | 2 +- internal/utils/stream/response.go | 26 +- internal/utils/stream/response_test.go | 55 ++++ 36 files changed, 1373 insertions(+), 368 deletions(-) create mode 100644 internal/core/plugin_daemon/basic.go create mode 100644 internal/core/plugin_daemon/model_service.go create mode 100644 internal/core/plugin_daemon/tool_service.go create mode 100644 internal/core/plugin_manager/stdio_holder/store.go create mode 100644 internal/service/runner.go create mode 100644 internal/types/entities/model_entities/llm_test.go create mode 100644 internal/types/entities/requests/model.go create mode 100644 internal/types/entities/requests/tool.go create mode 100644 internal/types/validators/validators.go rename internal/utils/{requests => http_requests}/http_options.go (98%) rename internal/utils/{requests => http_requests}/http_request.go (98%) rename internal/utils/{requests => http_requests}/http_warpper.go (99%) create mode 100644 internal/utils/stream/response_test.go diff --git a/internal/core/dify_invocation/http_request.go b/internal/core/dify_invocation/http_request.go index 1dafdea..b5853aa 100644 --- a/internal/core/dify_invocation/http_request.go +++ b/internal/core/dify_invocation/http_request.go @@ -1,42 +1,42 @@ package dify_invocation import ( - "github.com/langgenius/dify-plugin-daemon/internal/utils/requests" + "github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests" "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) -func Request[T any](method string, path string, options ...requests.HttpOptions) (*T, error) { +func Request[T any](method string, path string, options ...http_requests.HttpOptions) (*T, error) { options = append(options, - requests.HttpHeader(map[string]string{ + http_requests.HttpHeader(map[string]string{ "X-Inner-Api-Key": PLUGIN_INNER_API_KEY, }), - requests.HttpWriteTimeout(5000), - requests.HttpReadTimeout(60000), + http_requests.HttpWriteTimeout(5000), + http_requests.HttpReadTimeout(60000), ) - return requests.RequestAndParse[T](client, difyPath(path), method, options...) + return http_requests.RequestAndParse[T](client, difyPath(path), method, options...) } -func StreamResponse[T any](method string, path string, options ...requests.HttpOptions) (*stream.StreamResponse[T], error) { +func StreamResponse[T any](method string, path string, options ...http_requests.HttpOptions) (*stream.StreamResponse[T], error) { options = append( - options, requests.HttpHeader(map[string]string{ + options, http_requests.HttpHeader(map[string]string{ "X-Inner-Api-Key": PLUGIN_INNER_API_KEY, }), - requests.HttpWriteTimeout(5000), - requests.HttpReadTimeout(60000), + http_requests.HttpWriteTimeout(5000), + http_requests.HttpReadTimeout(60000), ) - return requests.RequestAndParseStream[T](client, difyPath(path), method, options...) + return http_requests.RequestAndParseStream[T](client, difyPath(path), method, options...) } func InvokeModel(payload *InvokeModelRequest) (*stream.StreamResponse[InvokeModelResponseChunk], error) { - return StreamResponse[InvokeModelResponseChunk]("POST", "invoke/model", requests.HttpPayloadJson(payload)) + return StreamResponse[InvokeModelResponseChunk]("POST", "invoke/model", http_requests.HttpPayloadJson(payload)) } func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[InvokeToolResponseChunk], error) { - return StreamResponse[InvokeToolResponseChunk]("POST", "invoke/tool", requests.HttpPayloadJson(payload)) + return StreamResponse[InvokeToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload)) } func InvokeNode[T WorkflowNodeData](payload *InvokeNodeRequest[T]) (*InvokeNodeResponse, error) { - return Request[InvokeNodeResponse]("POST", "invoke/node", requests.HttpPayloadJson(payload)) + return Request[InvokeNodeResponse]("POST", "invoke/node", http_requests.HttpPayloadJson(payload)) } diff --git a/internal/core/plugin_daemon/basic.go b/internal/core/plugin_daemon/basic.go new file mode 100644 index 0000000..565a3f2 --- /dev/null +++ b/internal/core/plugin_daemon/basic.go @@ -0,0 +1,31 @@ +package plugin_daemon + +type PluginAccessType string + +const ( + PLUGIN_ACCESS_TYPE_TOOL PluginAccessType = "tool" + PLUGIN_ACCESS_TYPE_MODEL PluginAccessType = "model" +) + +type PluginAccessAction string + +const ( + PLUGIN_ACCESS_ACTION_INVOKE_TOOL PluginAccessAction = "invoke_tool" + PLUGIN_ACCESS_ACTION_INVOKE_LLM PluginAccessAction = "invoke_llm" +) + +const ( + PLUGIN_IN_STREAM_EVENT = "request" +) + +func getBasicPluginAccessMap(session_id string, user_id string, access_type PluginAccessType, action PluginAccessAction) map[string]any { + return map[string]any{ + "session_id": session_id, + "event": PLUGIN_IN_STREAM_EVENT, + "data": map[string]any{ + "user_id": user_id, + "type": access_type, + "action": action, + }, + } +} diff --git a/internal/core/plugin_daemon/daemon.go b/internal/core/plugin_daemon/daemon.go index 986982e..2a166e1 100644 --- a/internal/core/plugin_daemon/daemon.go +++ b/internal/core/plugin_daemon/daemon.go @@ -1,66 +1 @@ package plugin_daemon - -import ( - "errors" - - "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" - "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" - "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" - "github.com/langgenius/dify-plugin-daemon/internal/utils/log" - "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" - "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" -) - -type ToolResponseChunk = plugin_entities.InvokeToolResponseChunk - -func InvokeTool(session *session_manager.Session, provider_name string, tool_name string, tool_parameters map[string]any) ( - *stream.StreamResponse[ToolResponseChunk], error, -) { - runtime := plugin_manager.Get(session.PluginIdentity()) - if runtime == nil { - return nil, errors.New("plugin not found") - } - - response := stream.NewStreamResponse[ToolResponseChunk](512) - - listener := runtime.Listen(session.ID()) - listener.AddListener(func(message []byte) { - chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message) - if err != nil { - log.Error("unmarshal json failed: %s", err.Error()) - return - } - - switch chunk.Type { - case plugin_entities.SESSION_MESSAGE_TYPE_STREAM: - chunk, err := parser.UnmarshalJsonBytes[ToolResponseChunk](chunk.Data) - if err != nil { - log.Error("unmarshal json failed: %s", err.Error()) - return - } - response.Write(chunk) - case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE: - invokeDify(runtime, session, chunk.Data) - case plugin_entities.SESSION_MESSAGE_TYPE_END: - response.Close() - default: - log.Error("unknown stream message type: %s", chunk.Type) - response.Close() - } - }) - - response.OnClose(func() { - listener.Close() - }) - - runtime.Write(session.ID(), []byte(parser.MarshalJson( - map[string]any{ - "provider": provider_name, - "tool": tool_name, - "parameters": tool_parameters, - "session_id": session.ID(), - }, - ))) - - return response, nil -} diff --git a/internal/core/plugin_daemon/invoke_dify.go b/internal/core/plugin_daemon/invoke_dify.go index a0f1cef..81ab586 100644 --- a/internal/core/plugin_daemon/invoke_dify.go +++ b/internal/core/plugin_daemon/invoke_dify.go @@ -11,7 +11,8 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" ) -func invokeDify(runtime entities.PluginRuntimeInterface, +func invokeDify( + runtime entities.PluginRuntimeInterface, session *session_manager.Session, data []byte, ) error { // unmarshal invoke data diff --git a/internal/core/plugin_daemon/model_service.go b/internal/core/plugin_daemon/model_service.go new file mode 100644 index 0000000..965296b --- /dev/null +++ b/internal/core/plugin_daemon/model_service.go @@ -0,0 +1,88 @@ +package plugin_daemon + +import ( + "errors" + + "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" + "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" + "github.com/langgenius/dify-plugin-daemon/internal/utils/log" + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" + "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" +) + +func getInvokeModelMap( + session *session_manager.Session, + action PluginAccessAction, + request *requests.RequestInvokeLLM, +) map[string]any { + req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_MODEL, action) + data := req["data"].(map[string]any) + + data["provider"] = request.Provider + data["model"] = request.Model + data["model_type"] = request.ModelType + data["model_parameters"] = request.ModelParameters + data["prompt_messages"] = request.PromptMessages + data["tools"] = request.Tools + data["stop"] = request.Stop + data["stream"] = request.Stream + data["credentials"] = request.Credentials + + return req +} + +func InvokeLLM( + session *session_manager.Session, + request *requests.RequestInvokeLLM, +) ( + *stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error, +) { + runtime := plugin_manager.Get(session.PluginIdentity()) + if runtime == nil { + return nil, errors.New("plugin not found") + } + + response := stream.NewStreamResponse[plugin_entities.InvokeModelResponseChunk](512) + + listener := runtime.Listen(session.ID()) + listener.AddListener(func(message []byte) { + chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message) + if err != nil { + log.Error("unmarshal json failed: %s", err.Error()) + return + } + + switch chunk.Type { + case plugin_entities.SESSION_MESSAGE_TYPE_STREAM: + chunk, err := parser.UnmarshalJsonBytes[plugin_entities.InvokeModelResponseChunk](chunk.Data) + if err != nil { + log.Error("unmarshal json failed: %s", err.Error()) + return + } + response.Write(chunk) + case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE: + invokeDify(runtime, session, chunk.Data) + case plugin_entities.SESSION_MESSAGE_TYPE_END: + response.Close() + default: + log.Error("unknown stream message type: %s", chunk.Type) + response.Close() + } + }) + + response.OnClose(func() { + listener.Close() + }) + + runtime.Write(session.ID(), []byte(parser.MarshalJson( + getInvokeModelMap( + session, + PLUGIN_ACCESS_ACTION_INVOKE_LLM, + request, + ), + ))) + + return response, nil +} diff --git a/internal/core/plugin_daemon/tool_service.go b/internal/core/plugin_daemon/tool_service.go new file mode 100644 index 0000000..5bdee61 --- /dev/null +++ b/internal/core/plugin_daemon/tool_service.go @@ -0,0 +1,86 @@ +package plugin_daemon + +import ( + "errors" + + "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" + "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" + "github.com/langgenius/dify-plugin-daemon/internal/utils/log" + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" + "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" +) + +func getInvokeToolMap( + session *session_manager.Session, + action PluginAccessAction, + request *requests.RequestInvokeTool, +) map[string]any { + req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_TOOL, action) + data := req["data"].(map[string]any) + + data["provider"] = request.Provider + data["tool"] = request.Tool + data["parameters"] = request.ToolParameters + data["credentials"] = request.Credentials + + return req +} + +func InvokeTool( + session *session_manager.Session, + request *requests.RequestInvokeTool, +) ( + *stream.StreamResponse[plugin_entities.ToolResponseChunk], error, +) { + runtime := plugin_manager.Get(session.PluginIdentity()) + if runtime == nil { + return nil, errors.New("plugin not found") + } + + response := stream.NewStreamResponse[plugin_entities.ToolResponseChunk](512) + + listener := runtime.Listen(session.ID()) + listener.AddListener(func(message []byte) { + chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message) + if err != nil { + log.Error("unmarshal json failed: %s", err.Error()) + return + } + + switch chunk.Type { + case plugin_entities.SESSION_MESSAGE_TYPE_STREAM: + chunk, err := parser.UnmarshalJsonBytes[plugin_entities.ToolResponseChunk](chunk.Data) + if err != nil { + log.Error("unmarshal json failed: %s", err.Error()) + return + } + response.Write(chunk) + case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE: + invokeDify(runtime, session, chunk.Data) + case plugin_entities.SESSION_MESSAGE_TYPE_END: + response.Close() + case plugin_entities.SESSION_MESSAGE_TYPE_ERROR: + e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data) + if err != nil { + break + } + response.WriteError(errors.New(e.Error)) + response.Close() + default: + response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type))) + response.Close() + } + }) + + response.OnClose(func() { + listener.Close() + }) + + runtime.Write(session.ID(), []byte(parser.MarshalJson( + getInvokeToolMap(session, PLUGIN_ACCESS_ACTION_INVOKE_TOOL, request)), + )) + + return response, nil +} diff --git a/internal/core/plugin_manager/aws_manager/run.go b/internal/core/plugin_manager/aws_manager/run.go index 3062138..667913a 100644 --- a/internal/core/plugin_manager/aws_manager/run.go +++ b/internal/core/plugin_manager/aws_manager/run.go @@ -4,3 +4,7 @@ func (r *AWSPluginRuntime) StartPlugin() error { return nil } + +func (r *AWSPluginRuntime) Wait() (<-chan bool, error) { + return nil, nil +} diff --git a/internal/core/plugin_manager/lifetime.go b/internal/core/plugin_manager/lifetime.go index 8e18d0d..29a083c 100644 --- a/internal/core/plugin_manager/lifetime.go +++ b/internal/core/plugin_manager/lifetime.go @@ -53,5 +53,14 @@ func lifetime(config *app.Config, r entities.PluginRuntimeInterface) { start_failed_times++ continue } + + // wait for plugin to stop + c, err := r.Wait() + if err == nil { + <-c + } + + // restart plugin in 5s + time.Sleep(5 * time.Second) } } diff --git a/internal/core/plugin_manager/local_manager/environment.go b/internal/core/plugin_manager/local_manager/environment.go index c5ec648..138fb08 100644 --- a/internal/core/plugin_manager/local_manager/environment.go +++ b/internal/core/plugin_manager/local_manager/environment.go @@ -19,7 +19,7 @@ func (r *LocalPluginRuntime) InitEnvironment() error { } // execute init command - handle := exec.Command("bash", "install.sh") + handle := exec.Command("bash", r.Config.Execution.Install) handle.Dir = r.State.RelativePath // get stdout and stderr diff --git a/internal/core/plugin_manager/local_manager/run.go b/internal/core/plugin_manager/local_manager/run.go index dddca7c..94e392f 100644 --- a/internal/core/plugin_manager/local_manager/run.go +++ b/internal/core/plugin_manager/local_manager/run.go @@ -1,6 +1,7 @@ package local_manager import ( + "errors" "fmt" "os/exec" "sync" @@ -11,36 +12,46 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" ) -func (r *LocalPluginRuntime) StartPlugin() error { +func (r *LocalPluginRuntime) gc() { + if r.io_identity != "" { + stdio_holder.Remove(r.io_identity) + } + + if r.w != nil { + close(r.w) + r.w = nil + } +} + +func (r *LocalPluginRuntime) init() { + r.w = make(chan bool) r.State.Status = entities.PLUGIN_RUNTIME_STATUS_LAUNCHING - defer func() { - r.io_identity = "" - }() +} + +func (r *LocalPluginRuntime) StartPlugin() error { defer log.Info("plugin %s stopped", r.Config.Identity()) + r.init() // start plugin - e := exec.Command("bash", "launch.sh") + e := exec.Command("bash", r.Config.Execution.Launch) e.Dir = r.State.RelativePath // get writer stdin, err := e.StdinPipe() if err != nil { r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING - e.Process.Kill() return fmt.Errorf("get stdin pipe failed: %s", err.Error()) } stdout, err := e.StdoutPipe() if err != nil { r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING - e.Process.Kill() return fmt.Errorf("get stdout pipe failed: %s", err.Error()) } stderr, err := e.StderrPipe() if err != nil { r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING - e.Process.Kill() return fmt.Errorf("get stderr pipe failed: %s", err.Error()) } @@ -56,17 +67,19 @@ func (r *LocalPluginRuntime) StartPlugin() error { r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING log.Error("plugin %s exited with error: %s", r.Config.Identity(), err.Error()) } + + r.gc() }() log.Info("plugin %s started", r.Config.Identity()) + // setup stdio stdio := stdio_holder.Put(r.Config.Identity(), stdin, stdout, stderr) - - // set io identity r.io_identity = stdio.GetID() + defer stdio.Stop() wg := sync.WaitGroup{} - wg.Add(1) + wg.Add(2) // listen to plugin stdout routine.Submit(func() { @@ -74,16 +87,30 @@ func (r *LocalPluginRuntime) StartPlugin() error { stdio.StartStdout() }) - err = stdio.StartStderr() + // listen to plugin stderr + routine.Submit(func() { + defer wg.Done() + stdio.StartStderr() + }) + + // wait for plugin to exit + err = stdio.Wait() if err != nil { - r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING - e.Process.Kill() return err } + e.Process.Kill() + wg.Wait() // plugin has exited r.State.Status = entities.PLUGIN_RUNTIME_STATUS_PENDING return nil } + +func (r *LocalPluginRuntime) Wait() (<-chan bool, error) { + if r.w == nil { + return nil, errors.New("plugin not started") + } + return r.w, nil +} diff --git a/internal/core/plugin_manager/local_manager/type.go b/internal/core/plugin_manager/local_manager/type.go index e571878..944a48f 100644 --- a/internal/core/plugin_manager/local_manager/type.go +++ b/internal/core/plugin_manager/local_manager/type.go @@ -6,4 +6,5 @@ type LocalPluginRuntime struct { entities.PluginRuntime io_identity string + w chan bool } diff --git a/internal/core/plugin_manager/stdio_holder/io.go b/internal/core/plugin_manager/stdio_holder/io.go index e951b11..d89cb11 100644 --- a/internal/core/plugin_manager/stdio_holder/io.go +++ b/internal/core/plugin_manager/stdio_holder/io.go @@ -2,11 +2,12 @@ package stdio_holder import ( "bufio" + "errors" "fmt" "io" "sync" + "time" - "github.com/google/uuid" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/utils/log" "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" @@ -19,200 +20,160 @@ var ( ) type stdioHolder struct { - id string - pluginIdentity string - writer io.WriteCloser - reader io.ReadCloser - errReader io.ReadCloser - l *sync.Mutex - listener map[string]func([]byte) - started bool - alive bool + id string + plugin_identity string + writer io.WriteCloser + reader io.ReadCloser + err_reader io.ReadCloser + l *sync.Mutex + listener map[string]func([]byte) + error_listener map[string]func([]byte) + started bool + + err_message string + last_err_message_updated_at time.Time + + health_chan chan bool + health_chan_closed bool + health_chan_lock *sync.Mutex + last_active_at time.Time +} + +func (s *stdioHolder) Error() error { + if time.Since(s.last_err_message_updated_at) < 60*time.Second { + if s.err_message != "" { + return errors.New(s.err_message) + } + } + + return nil } func (s *stdioHolder) Stop() { - s.alive = false s.writer.Close() s.reader.Close() - s.errReader.Close() + s.err_reader.Close() + + s.health_chan_lock.Lock() + if !s.health_chan_closed { + close(s.health_chan) + s.health_chan_closed = true + } + s.health_chan_lock.Unlock() stdio_holder.Delete(s.id) } func (s *stdioHolder) StartStdout() { s.started = true - s.alive = true + defer s.Stop() scanner := bufio.NewScanner(s.reader) - for s.alive { - for scanner.Scan() { - data := scanner.Bytes() - event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data) - if err != nil { - log.Error("unmarshal json failed: %s", err.Error()) - continue - } - - session_id := event.SessionId + for scanner.Scan() { + data := scanner.Bytes() + event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data) + if err != nil { + // log.Error("unmarshal json failed: %s", err.Error()) + continue + } - switch event.Event { - case plugin_entities.PLUGIN_EVENT_LOG: - if event.Event == plugin_entities.PLUGIN_EVENT_LOG { - logEvent, err := parser.UnmarshalJsonBytes[plugin_entities.PluginLogEvent](event.Data) - if err != nil { - log.Error("unmarshal json failed: %s", err.Error()) - continue - } + session_id := event.SessionId - log.Info("plugin %s: %s", s.pluginIdentity, logEvent.Message) - } - case plugin_entities.PLUGIN_EVENT_SESSION: - for _, listener := range listeners { - listener(s.id, event.Data) + switch event.Event { + case plugin_entities.PLUGIN_EVENT_LOG: + if event.Event == plugin_entities.PLUGIN_EVENT_LOG { + logEvent, err := parser.UnmarshalJsonBytes[plugin_entities.PluginLogEvent](event.Data) + if err != nil { + log.Error("unmarshal json failed: %s", err.Error()) + continue } - for listener_session_id, listener := range s.listener { - if listener_session_id == session_id { - listener(event.Data) - } + log.Info("plugin %s: %s", s.plugin_identity, logEvent.Message) + } + case plugin_entities.PLUGIN_EVENT_SESSION: + for _, listener := range listeners { + listener(s.id, event.Data) + } + + for listener_session_id, listener := range s.listener { + if listener_session_id == session_id { + listener(event.Data) } - case plugin_entities.PLUGIN_EVENT_ERROR: - log.Error("plugin %s: %s", s.pluginIdentity, event.Data) } + case plugin_entities.PLUGIN_EVENT_ERROR: + for listener_session_id, listener := range s.error_listener { + if listener_session_id == session_id { + listener(event.Data) + } + } + case plugin_entities.PLUGIN_EVENT_HEARTBEAT: + s.last_active_at = time.Now() } } } -/* - * @return error - */ -func (s *stdioHolder) StartStderr() error { - s.started = true - s.alive = true - defer s.Stop() - for s.alive { +func (s *stdioHolder) WriteError(msg string) { + const MAX_ERR_MSG_LEN = 1024 + reduce := len(msg) + len(s.err_message) - MAX_ERR_MSG_LEN + if reduce > 0 { + s.err_message = s.err_message[reduce:] + } + + s.err_message += msg + s.last_err_message_updated_at = time.Now() +} + +func (s *stdioHolder) StartStderr() { + for { buf := make([]byte, 1024) - n, err := s.errReader.Read(buf) + n, err := s.err_reader.Read(buf) if err != nil && err != io.EOF { - return err + break } else if err != nil { - return nil + s.WriteError(fmt.Sprintf("%s\n", buf[:n])) + break } if n > 0 { - return fmt.Errorf("stderr: %s", buf[:n]) + s.WriteError(fmt.Sprintf("%s\n", buf[:n])) } } - - return nil } -func (s *stdioHolder) GetID() string { - return s.id -} - -/* - * @param plugin_identity: string - * @param writer: io.WriteCloser - * @param reader: io.ReadCloser - * @param errReader: io.ReadCloser - */ -func Put( - plugin_identity string, - writer io.WriteCloser, - reader io.ReadCloser, - errReader io.ReadCloser, -) *stdioHolder { - id := uuid.New().String() - - holder := &stdioHolder{ - pluginIdentity: plugin_identity, - writer: writer, - reader: reader, - errReader: errReader, - id: id, - l: &sync.Mutex{}, +func (s *stdioHolder) Wait() error { + s.health_chan_lock.Lock() + if s.health_chan_closed { + s.health_chan_lock.Unlock() + return errors.New("you need to start the health check before waiting") } + s.health_chan_lock.Unlock() - stdio_holder.Store(id, holder) - return holder -} + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() -/* - * @param id: string - */ -func Get(id string) *stdioHolder { - if v, ok := stdio_holder.Load(id); ok { - if holder, ok := v.(*stdioHolder); ok { - return holder + // check status of plugin every 5 seconds + for { + s.health_chan_lock.Lock() + if s.health_chan_closed { + s.health_chan_lock.Unlock() + break } - } - - return nil -} - -/* - * @param id: string - */ -func Remove(id string) { - stdio_holder.Delete(id) -} - -/* - * @param id: string - * @param session_id: string - * @param listener: func(data []byte) - * @return string - listener identity - */ -func OnEvent(id string, session_id string, listener func([]byte)) { - if v, ok := stdio_holder.Load(id); ok { - if holder, ok := v.(*stdioHolder); ok { - holder.l.Lock() - defer holder.l.Unlock() - if holder.listener == nil { - holder.listener = map[string]func([]byte){} + s.health_chan_lock.Unlock() + select { + case <-ticker.C: + // check heartbeat + if time.Since(s.last_active_at) > 20*time.Second { + return errors.New("plugin is not active") } - - holder.listener[session_id] = listener + case <-s.health_chan: + // closed + return s.Error() } } -} -/* - * @param id: string - * @param listener: string - */ -func RemoveListener(id string, listener string) { - if v, ok := stdio_holder.Load(id); ok { - if holder, ok := v.(*stdioHolder); ok { - holder.l.Lock() - defer holder.l.Unlock() - delete(holder.listener, listener) - } - } -} - -/* - * @param listener: func(id string, data []byte) - */ -func OnGlobalEvent(listener func(string, []byte)) { - l.Lock() - defer l.Unlock() - listeners[uuid.New().String()] = listener + return nil } -/* - * @param id: string - * @param data: []byte - */ -func Write(id string, data []byte) error { - if v, ok := stdio_holder.Load(id); ok { - if holder, ok := v.(*stdioHolder); ok { - _, err := holder.writer.Write(data) - - return err - } - } - - return nil +func (s *stdioHolder) GetID() string { + return s.id } diff --git a/internal/core/plugin_manager/stdio_holder/store.go b/internal/core/plugin_manager/stdio_holder/store.go new file mode 100644 index 0000000..1a536f8 --- /dev/null +++ b/internal/core/plugin_manager/stdio_holder/store.go @@ -0,0 +1,102 @@ +package stdio_holder + +import ( + "io" + "sync" + + "github.com/google/uuid" +) + +func Put( + plugin_identity string, writer io.WriteCloser, + reader io.ReadCloser, err_reader io.ReadCloser, +) *stdioHolder { + id := uuid.New().String() + + holder := &stdioHolder{ + plugin_identity: plugin_identity, + writer: writer, + reader: reader, + err_reader: err_reader, + id: id, + l: &sync.Mutex{}, + + health_chan_lock: &sync.Mutex{}, + health_chan: make(chan bool), + } + + stdio_holder.Store(id, holder) + return holder +} + +func Get(id string) *stdioHolder { + if v, ok := stdio_holder.Load(id); ok { + if holder, ok := v.(*stdioHolder); ok { + return holder + } + } + + return nil +} + +func Remove(id string) { + stdio_holder.Delete(id) +} + +func OnEvent(id string, session_id string, listener func([]byte)) { + if v, ok := stdio_holder.Load(id); ok { + if holder, ok := v.(*stdioHolder); ok { + holder.l.Lock() + defer holder.l.Unlock() + if holder.listener == nil { + holder.listener = map[string]func([]byte){} + } + + holder.listener[session_id] = listener + } + } +} + +func OnError(id string, session_id string, listener func([]byte)) { + if v, ok := stdio_holder.Load(id); ok { + if holder, ok := v.(*stdioHolder); ok { + holder.l.Lock() + defer holder.l.Unlock() + if holder.error_listener == nil { + holder.error_listener = map[string]func([]byte){} + } + + holder.error_listener[session_id] = listener + } + } + +} + +func RemoveListener(id string, listener string) { + if v, ok := stdio_holder.Load(id); ok { + if holder, ok := v.(*stdioHolder); ok { + holder.l.Lock() + defer holder.l.Unlock() + delete(holder.listener, listener) + delete(holder.error_listener, listener) + } + } +} + +func OnGlobalEvent(listener func(string, []byte)) { + l.Lock() + defer l.Unlock() + listeners[uuid.New().String()] = listener +} + +func Write(id string, data []byte) error { + if v, ok := stdio_holder.Load(id); ok { + if holder, ok := v.(*stdioHolder); ok { + _, err := holder.writer.Write(data) + + return err + } + } + + return nil +} diff --git a/internal/core/plugin_manager/watcher.go b/internal/core/plugin_manager/watcher.go index 8468f4e..ad9d3c0 100644 --- a/internal/core/plugin_manager/watcher.go +++ b/internal/core/plugin_manager/watcher.go @@ -66,7 +66,7 @@ func loadNewPlugins(root_path string) <-chan entities.PluginRuntime { routine.Submit(func() { for _, plugin := range plugins { if plugin.IsDir() { - configuration_path := path.Join(root_path, plugin.Name(), "manifest.json") + configuration_path := path.Join(root_path, plugin.Name(), "manifest.yaml") configuration, err := parsePluginConfig(configuration_path) if err != nil { log.Error("parse plugin config error: %v", err) @@ -78,6 +78,10 @@ func loadNewPlugins(root_path string) <-chan entities.PluginRuntime { continue } + // check if .verified file exists + verified_path := path.Join(root_path, plugin.Name(), ".verified") + _, err = os.Stat(verified_path) + ch <- entities.PluginRuntime{ Config: *configuration, State: entities.PluginRuntimeState{ @@ -85,7 +89,7 @@ func loadNewPlugins(root_path string) <-chan entities.PluginRuntime { Restarts: 0, RelativePath: path.Join(root_path, plugin.Name()), ActiveAt: nil, - Verified: false, + Verified: err == nil, }, } } diff --git a/internal/server/controller.go b/internal/server/controller.go index f381aaa..57749f8 100644 --- a/internal/server/controller.go +++ b/internal/server/controller.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/service" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" ) func HealthCheck(c *gin.Context) { @@ -11,7 +12,7 @@ func HealthCheck(c *gin.Context) { } func InvokeTool(c *gin.Context) { - type request = plugin_entities.InvokePluginRequest[plugin_entities.InvokeToolRequest] + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTool] BindRequest[request]( c, @@ -20,3 +21,14 @@ func InvokeTool(c *gin.Context) { }, ) } + +func InvokeLLM(c *gin.Context) { + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM] + + BindRequest[request]( + c, + func(itr request) { + service.InvokeLLM(&itr, c) + }, + ) +} diff --git a/internal/server/http.go b/internal/server/http.go index 80fa5f9..458ae10 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -12,6 +12,7 @@ func server(config *app.Config) { engine.GET("/health/check", HealthCheck) engine.POST("/plugin/tool/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTool) + engine.POST("/plugin/llm/invoke", CheckingKey(config.PluginInnerApiKey), InvokeLLM) engine.Run(fmt.Sprintf(":%d", config.SERVER_PORT)) } diff --git a/internal/service/invoke.go b/internal/service/invoke.go index b169a14..f89b92a 100644 --- a/internal/service/invoke.go +++ b/internal/service/invoke.go @@ -4,61 +4,28 @@ import ( "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" - "github.com/langgenius/dify-plugin-daemon/internal/types/entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" - "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" + "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) -func InvokeTool(r *plugin_entities.InvokePluginRequest[plugin_entities.InvokeToolRequest], ctx *gin.Context) { +func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool], ctx *gin.Context) { // create session session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) defer session.Close() - writer := ctx.Writer - writer.WriteHeader(200) - writer.Header().Set("Content-Type", "text/event-stream") - - done := make(chan bool) - - write_data := func(data interface{}) { - writer.WriteString("data: ") - writer.Write([]byte(parser.MarshalJson(data))) - writer.Write([]byte("\n\n")) - writer.Flush() - } - - plugin_daemon_response, err := plugin_daemon.InvokeTool( - session, - r.Data.ProviderName, - r.Data.ToolName, - r.Data.Parameters, - ) - - if err != nil { - write_data(entities.NewErrorResponse(-500, err.Error())) - close(done) - return - } - - routine.Submit(func() { - for plugin_daemon_response.Next() { - chunk, err := plugin_daemon_response.Read() - if err != nil { - break - } - write_data(entities.NewSuccessResponse(chunk)) - } - close(done) - }) - - select { - case <-writer.CloseNotify(): - plugin_daemon_response.Close() - case <-done: - } + baseSSEService(r, func() (*stream.StreamResponse[plugin_entities.ToolResponseChunk], error) { + return plugin_daemon.InvokeTool(session, &r.Data) + }, ctx) } -func InvokeModel(r *plugin_entities.InvokePluginRequest[plugin_entities.InvokeModelRequest], ctx *gin.Context) { +func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM], ctx *gin.Context) { + // create session + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + defer session.Close() + baseSSEService(r, func() (*stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error) { + return plugin_daemon.InvokeLLM(session, &r.Data) + }, ctx) } diff --git a/internal/service/runner.go b/internal/service/runner.go new file mode 100644 index 0000000..b39d114 --- /dev/null +++ b/internal/service/runner.go @@ -0,0 +1,80 @@ +package service + +import ( + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" + "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" + "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" +) + +func baseSSEService[T any, R any]( + r *plugin_entities.InvokePluginRequest[T], + generator func() (*stream.StreamResponse[R], error), + ctx *gin.Context, +) { + writer := ctx.Writer + writer.WriteHeader(200) + writer.Header().Set("Content-Type", "text/event-stream") + + done := make(chan bool) + done_closed := new(int32) + + write_data := func(data interface{}) { + writer.Write([]byte("data: ")) + writer.Write(parser.MarshalJsonBytes(data)) + writer.Write([]byte("\n\n")) + writer.Flush() + } + + plugin_daemon_response, err := generator() + last_response_at := time.Now() + + if err != nil { + write_data(entities.NewErrorResponse(-500, err.Error())) + close(done) + return + } + + routine.Submit(func() { + for plugin_daemon_response.Next() { + last_response_at = time.Now() + chunk, err := plugin_daemon_response.Read() + if err != nil { + write_data(entities.NewErrorResponse(-500, err.Error())) + break + } + write_data(entities.NewSuccessResponse(chunk)) + } + + if atomic.CompareAndSwapInt32(done_closed, 0, 1) { + close(done) + } + }) + + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for { + select { + case <-writer.CloseNotify(): + plugin_daemon_response.Close() + return + case <-done: + return + case <-ticker.C: + if time.Since(last_response_at) > 30*time.Second { + write_data(entities.NewErrorResponse(-500, "killed by timeout")) + if atomic.CompareAndSwapInt32(done_closed, 0, 1) { + close(done) + } + return + } + } + + } +} diff --git a/internal/types/entities/model_entities/llm.go b/internal/types/entities/model_entities/llm.go index 31a7a05..13ad677 100644 --- a/internal/types/entities/model_entities/llm.go +++ b/internal/types/entities/model_entities/llm.go @@ -1,11 +1,23 @@ package model_entities +import ( + "encoding/json" + "errors" + + "github.com/go-playground/validator/v10" + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" + "github.com/shopspring/decimal" +) + type ModelType string const ( MODEL_TYPE_LLM ModelType = "llm" MODEL_TYPE_TEXT_EMBEDDING ModelType = "text_embedding" MODEL_TYPE_RERANKING ModelType = "rerank" + MODEL_TYPE_SPEECH2TEXT ModelType = "speech2text" + MODEL_TYPE_TTS ModelType = "tts" + MODEL_TYPE_MODERATION ModelType = "moderation" ) type LLMModel string @@ -14,3 +26,234 @@ const ( LLM_MODE_CHAT LLMModel = "chat" LLM_MODE_COMPLETION LLMModel = "completion" ) + +type PromptMessageRole string + +const ( + PROMPT_MESSAGE_ROLE_SYSTEM = "system" + PROMPT_MESSAGE_ROLE_USER = "user" + PROMPT_MESSAGE_ROLE_ASSISTANT = "assistant" + PROMPT_MESSAGE_ROLE_TOOL = "tool" +) + +func isPromptMessageRole(fl validator.FieldLevel) bool { + value := fl.Field().String() + switch value { + case string(PROMPT_MESSAGE_ROLE_SYSTEM), + string(PROMPT_MESSAGE_ROLE_USER), + string(PROMPT_MESSAGE_ROLE_ASSISTANT), + string(PROMPT_MESSAGE_ROLE_TOOL): + return true + } + return false +} + +type PromptMessage struct { + Role PromptMessageRole `json:"role" validate:"required,prompt_message_role"` + Content any `json:"content" validate:"required,prompt_message_content"` + Name string `json:"name"` + ToolCalls []PromptMessageToolCall `json:"tool_calls" validate:"dive"` + ToolCallId string `json:"tool_call_id"` +} + +func isPromptMessageContent(fl validator.FieldLevel) bool { + // only allow string or []PromptMessageContent + value := fl.Field().Interface() + switch value_type := value.(type) { + case string: + return true + case []PromptMessageContent: + // validate the content + for _, content := range value_type { + if err := validators.GlobalEntitiesValidator.Struct(content); err != nil { + return false + } + } + return true + } + return false +} + +type PromptMessageContentType string + +const ( + PROMPT_MESSAGE_CONTENT_TYPE_TEXT PromptMessageContentType = "text" + PROMPT_MESSAGE_CONTENT_TYPE_IMAGE PromptMessageContentType = "image" +) + +func isPromptMessageContentType(fl validator.FieldLevel) bool { + value := fl.Field().String() + switch value { + case string(PROMPT_MESSAGE_CONTENT_TYPE_TEXT), + string(PROMPT_MESSAGE_CONTENT_TYPE_IMAGE): + return true + } + return false +} + +type PromptMessageContent struct { + Type PromptMessageContentType `json:"type" validate:"required,prompt_message_content_type"` + Data string `json:"data" validate:"required"` +} + +type PromptMessageToolCall struct { + // TODO: +} + +func init() { + validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_role", isPromptMessageRole) + validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content", isPromptMessageContent) + validators.GlobalEntitiesValidator.RegisterValidation("prompt_message_content_type", isPromptMessageContentType) +} + +func (p *PromptMessage) UnmarshalJSON(data []byte) error { + // Unmarshal the JSON data into a map + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Check if content is a string or an array which contains type and content + if _, ok := raw["content"]; ok { + var content string + if err := json.Unmarshal(raw["content"], &content); err == nil { + p.Content = content + } else { + var content []PromptMessageContent + if err := json.Unmarshal(raw["content"], &content); err != nil { + return err + } + p.Content = content + } + } else { + return errors.New("content field is required") + } + + // Unmarshal the rest of the fields + if role, ok := raw["role"]; ok { + if err := json.Unmarshal(role, &p.Role); err != nil { + return err + } + } else { + return errors.New("role field is required") + } + + if name, ok := raw["name"]; ok { + if err := json.Unmarshal(name, &p.Name); err != nil { + return err + } + } + + if tool_calls, ok := raw["tool_calls"]; ok { + if err := json.Unmarshal(tool_calls, &p.ToolCalls); err != nil { + return err + } + } + + if tool_call_id, ok := raw["tool_call_id"]; ok { + if err := json.Unmarshal(tool_call_id, &p.ToolCallId); err != nil { + return err + } + } + + // Validate the struct + if err := validators.GlobalEntitiesValidator.Struct(p); err != nil { + return err + } + + // validate tool call id + if p.Role == PROMPT_MESSAGE_ROLE_TOOL && p.ToolCallId == "" { + return errors.New("tool call id is required") + } + + return nil +} + +type PromptMessageTool struct { + Name string `json:"name" validate:"required"` + Description string `json:"description" validate:"required"` + Parameters map[string]any `json:"parameters"` +} + +func (p *PromptMessageTool) UnmarshalJSON(data []byte) error { + type Alias PromptMessageTool + aux := &struct { + *Alias + }{ + Alias: (*Alias)(p), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if err := validators.GlobalEntitiesValidator.Struct(p); err != nil { + return err + } + + return nil +} + +type LLMResultChunk struct { + Model LLMModel `json:"model" validate:"required"` + PromptMessages []PromptMessage `json:"prompt_messages" validate:"required,dive"` + SystemFingerprint string `json:"system_fingerprint" validate:"omitempty"` + Delta LLMResultChunkDelta `json:"delta" validate:"required"` +} + +func (l *LLMResultChunk) UnmarshalJSON(data []byte) error { + type Alias LLMResultChunk + aux := &struct { + *Alias + }{ + Alias: (*Alias)(l), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if err := validators.GlobalEntitiesValidator.Struct(l); err != nil { + return err + } + + return nil +} + +type LLMUsage struct { + PromptTokens *int `json:"prompt_tokens" validate:"required"` + PromptUnitPrice decimal.Decimal `json:"prompt_unit_price" validate:"required"` + PromptPriceUnit decimal.Decimal `json:"prompt_price_unit" validate:"required"` + PromptPrice decimal.Decimal `json:"prompt_price" validate:"required"` + CompletionTokens *int `json:"completion_tokens" validate:"required"` + CompletionUnitPrice decimal.Decimal `json:"completion_unit_price" validate:"required"` + CompletionPriceUnit decimal.Decimal `json:"completion_price_unit" validate:"required"` + CompletionPrice decimal.Decimal `json:"completion_price" validate:"required"` + TotalTokens *int `json:"total_tokens" validate:"required"` + TotalPrice decimal.Decimal `json:"total_price" validate:"required"` + Currency *string `json:"currency" validate:"required"` + Latency *float64 `json:"latency" validate:"required"` +} + +func (l *LLMUsage) UnmarshalJSON(data []byte) error { + type Alias LLMUsage + aux := &struct { + *Alias + }{ + Alias: (*Alias)(l), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if err := validators.GlobalEntitiesValidator.Struct(l); err != nil { + return err + } + + return nil +} + +type LLMResultChunkDelta struct { + Index *int `json:"index" validate:"required"` + Message PromptMessage `json:"message" validate:"required"` + Usage *LLMUsage `json:"usage" validate:"omitempty"` + FinishReason *string `json:"finish_reason" validate:"omitempty"` +} diff --git a/internal/types/entities/model_entities/llm_test.go b/internal/types/entities/model_entities/llm_test.go new file mode 100644 index 0000000..78fec14 --- /dev/null +++ b/internal/types/entities/model_entities/llm_test.go @@ -0,0 +1,278 @@ +package model_entities + +import ( + "encoding/json" + "testing" +) + +func TestFullFunctionPromptMessage(t *testing.T) { + const ( + system_message = ` + { + "role": "system", + "content": "you are a helpful assistant" + } + ` + user_message = ` + { + "role": "user", + "content": "hello" + }` + assistant_message = ` + { + "role": "assistant", + "content": "you are a helpful assistant" + }` + image_message = ` + { + "role": "user", + "content": [ + { + "type": "image", + "data": "base64" + } + ] + }` + tool_message = ` + { + "role": "tool", + "content": "you are a helpful assistant", + "tool_call_id": "123" + } + ` + ) + + var prompt_message PromptMessage + + err := json.Unmarshal([]byte(system_message), &prompt_message) + if err != nil { + t.Error(err) + } + if prompt_message.Role != "system" { + t.Error("role is not system") + } + + err = json.Unmarshal([]byte(user_message), &prompt_message) + if err != nil { + t.Error(err) + } + if prompt_message.Role != "user" { + t.Error("role is not user") + } + + err = json.Unmarshal([]byte(assistant_message), &prompt_message) + if err != nil { + t.Error(err) + } + if prompt_message.Role != "assistant" { + t.Error("role is not assistant") + } + + err = json.Unmarshal([]byte(image_message), &prompt_message) + if err != nil { + t.Error(err) + } + if prompt_message.Role != "user" { + t.Error("role is not user") + } + if prompt_message.Content.([]PromptMessageContent)[0].Type != "image" { + t.Error("type is not image") + } + + err = json.Unmarshal([]byte(tool_message), &prompt_message) + if err != nil { + t.Error(err) + } + if prompt_message.Role != "tool" { + t.Error("role is not tool") + } + if prompt_message.ToolCallId != "123" { + t.Error("tool call id is not 123") + } +} + +func TestWrongRole(t *testing.T) { + const ( + wrong_role = ` + { + "role": "wrong", + "content": "you are a helpful assistant" + } + ` + ) + + var prompt_message PromptMessage + + err := json.Unmarshal([]byte(wrong_role), &prompt_message) + if err == nil { + t.Error("error is nil") + } +} + +func TestWrongContent(t *testing.T) { + const ( + wrong_content = ` + { + "role": "user", + "content": 123 + } + ` + ) + + var prompt_message PromptMessage + + err := json.Unmarshal([]byte(wrong_content), &prompt_message) + if err == nil { + t.Error("error is nil") + } +} + +func TestWrongContentArray(t *testing.T) { + const ( + wrong_content_array = ` + { + "role": "user", + "content": [ + { + "type": "image", + "data": 123 + } + ] + } + ` + ) + + var prompt_message PromptMessage + + err := json.Unmarshal([]byte(wrong_content_array), &prompt_message) + if err == nil { + t.Error("error is nil") + } +} + +func TestWrongContentArray2(t *testing.T) { + const ( + wrong_content_array2 = ` + { + "role": "user", + "content": [ + { + "type": "image" + } + ] + } + ` + ) + + var prompt_message PromptMessage + + err := json.Unmarshal([]byte(wrong_content_array2), &prompt_message) + if err == nil { + t.Error("error is nil") + } +} + +func TestWrongContentArray3(t *testing.T) { + const ( + wrong_content_array3 = ` + { + "role": "user", + "content": [ + { + "type": "wwww", + "data": "base64" + }, + { + "type": "image", + "data": "base64" + } + ] + } + ` + ) + + var prompt_message PromptMessage + + err := json.Unmarshal([]byte(wrong_content_array3), &prompt_message) + if err == nil { + t.Error("error is nil") + } +} + +func TestFullFunctionLLMResultChunk(t *testing.T) { + const ( + llm_result_chunk = ` + { + "model": "gpt-3.5-turbo", + "prompt_messages": [ + { + "role": "system", + "content": "you are a helpful assistant" + }, + { + "role": "user", + "content": "hello" + } + ], + "system_fingerprint": "123", + "delta": { + "index" : 0, + "message" : { + "role": "assistant", + "content": "I am a helpful assistant" + }, + "usage": { + "prompt_tokens": 10, + "prompt_unit_price": 0.1, + "prompt_price_unit": 1, + "prompt_price": 1, + "completion_tokens": 10, + "completion_unit_price": 0.1, + "completion_price_unit": 1, + "completion_price": 1, + "total_tokens": 20, + "total_price": 2, + "currency": "usd", + "latency": 0.1 + }, + "finish_reason": "completed" + } + } + ` + ) + + var c LLMResultChunk + + err := json.Unmarshal([]byte(llm_result_chunk), &c) + if err != nil { + t.Error(err) + } +} + +func TestZeroLLMUsage(t *testing.T) { + const ( + llm_usage = ` + { + "prompt_tokens": 0, + "prompt_unit_price": 0, + "prompt_price_unit": 0, + "prompt_price": 0, + "completion_tokens": 0, + "completion_unit_price": 0, + "completion_price_unit": 0, + "completion_price": 0, + "total_tokens": 0, + "total_price": 0, + "currency": "usd", + "latency": 0 + } + ` + ) + + var u LLMUsage + + err := json.Unmarshal([]byte(llm_usage), &u) + if err != nil { + t.Error(err) + } +} diff --git a/internal/types/entities/plugin_entities/basic_type_test.go b/internal/types/entities/plugin_entities/basic_type_test.go index 5a61f92..a9bf3ba 100644 --- a/internal/types/entities/plugin_entities/basic_type_test.go +++ b/internal/types/entities/plugin_entities/basic_type_test.go @@ -1,6 +1,10 @@ package plugin_entities -import "testing" +import ( + "testing" + + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" +) func TestGenericType_Validate(t *testing.T) { type F struct { @@ -13,7 +17,7 @@ func TestGenericType_Validate(t *testing.T) { }, } - if err := global_tool_provider_validator.Struct(f); err != nil { + if err := validators.GlobalEntitiesValidator.Struct(f); err != nil { t.Errorf("GenericType_Validate() error = %v", err) } } @@ -29,7 +33,7 @@ func TestWrongGenericType_Validate(t *testing.T) { }, } - if err := global_tool_provider_validator.Struct(f); err == nil { + if err := validators.GlobalEntitiesValidator.Struct(f); err == nil { t.Error("WrongGenericType_Validate() error = nil; want error") } } diff --git a/internal/types/entities/plugin_entities/event.go b/internal/types/entities/plugin_entities/event.go index 6d66ed8..ffd4c8b 100644 --- a/internal/types/entities/plugin_entities/event.go +++ b/internal/types/entities/plugin_entities/event.go @@ -1,6 +1,10 @@ package plugin_entities -import "encoding/json" +import ( + "encoding/json" + + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" +) type PluginUniversalEvent struct { Event PluginEventType `json:"event"` @@ -11,9 +15,10 @@ type PluginUniversalEvent struct { type PluginEventType string const ( - PLUGIN_EVENT_LOG PluginEventType = "log" - PLUGIN_EVENT_SESSION PluginEventType = "session" - PLUGIN_EVENT_ERROR PluginEventType = "error" + PLUGIN_EVENT_LOG PluginEventType = "log" + PLUGIN_EVENT_SESSION PluginEventType = "session" + PLUGIN_EVENT_ERROR PluginEventType = "error" + PLUGIN_EVENT_HEARTBEAT PluginEventType = "heartbeat" ) type PluginLogEvent struct { @@ -32,13 +37,22 @@ type SESSION_MESSAGE_TYPE string const ( SESSION_MESSAGE_TYPE_STREAM SESSION_MESSAGE_TYPE = "stream" SESSION_MESSAGE_TYPE_END SESSION_MESSAGE_TYPE = "end" + SESSION_MESSAGE_TYPE_ERROR SESSION_MESSAGE_TYPE = "error" SESSION_MESSAGE_TYPE_INVOKE SESSION_MESSAGE_TYPE = "invoke" ) -type InvokeToolResponseChunk struct { +type ToolResponseChunk struct { + Type string `json:"type"` + Message map[string]any `json:"message"` +} + +type PluginResponseChunk struct { Type string `json:"type"` Data json.RawMessage `json:"data"` } -type InvokeModelResponseChunk struct { +type InvokeModelResponseChunk = model_entities.LLMResultChunk + +type ErrorResponse struct { + Error string `json:"error"` } diff --git a/internal/types/entities/plugin_entities/model_configuration.go b/internal/types/entities/plugin_entities/model_configuration.go index ca13265..f4179ce 100644 --- a/internal/types/entities/plugin_entities/model_configuration.go +++ b/internal/types/entities/plugin_entities/model_configuration.go @@ -7,6 +7,7 @@ import ( ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" en_translations "github.com/go-playground/validator/v10/translations/en" + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" "github.com/shopspring/decimal" ) @@ -213,20 +214,16 @@ type ModelProviderConfiguration struct { ModelCredentialSchema *ModelCredentialSchema `json:"model_credential_schema" validate:"omitempty"` } -var ( - global_model_provider_validator = validator.New() -) - func init() { // init validator en := en.New() uni := ut.New(en, en) translator, _ := uni.GetTranslator("en") // register translations for default validators - en_translations.RegisterDefaultTranslations(global_model_provider_validator, translator) + en_translations.RegisterDefaultTranslations(validators.GlobalEntitiesValidator, translator) - global_model_provider_validator.RegisterValidation("model_type", isModelType) - global_model_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("model_type", isModelType) + validators.GlobalEntitiesValidator.RegisterTranslation( "model_type", translator, func(ut ut.Translator) error { @@ -238,8 +235,8 @@ func init() { }, ) - global_model_provider_validator.RegisterValidation("model_provider_configurate_method", isModelProviderConfigurateMethod) - global_model_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("model_provider_configurate_method", isModelProviderConfigurateMethod) + validators.GlobalEntitiesValidator.RegisterTranslation( "model_provider_configurate_method", translator, func(ut ut.Translator) error { @@ -251,8 +248,8 @@ func init() { }, ) - global_model_provider_validator.RegisterValidation("model_provider_form_type", isModelProviderFormType) - global_model_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("model_provider_form_type", isModelProviderFormType) + validators.GlobalEntitiesValidator.RegisterTranslation( "model_provider_form_type", translator, func(ut ut.Translator) error { @@ -264,8 +261,8 @@ func init() { }, ) - global_model_provider_validator.RegisterValidation("model_parameter_type", isModelParameterType) - global_model_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("model_parameter_type", isModelParameterType) + validators.GlobalEntitiesValidator.RegisterTranslation( "model_parameter_type", translator, func(ut ut.Translator) error { @@ -277,9 +274,9 @@ func init() { }, ) - global_model_provider_validator.RegisterValidation("parameter_rule", isParameterRule) + validators.GlobalEntitiesValidator.RegisterValidation("parameter_rule", isParameterRule) - global_model_provider_validator.RegisterValidation("is_basic_type", isGenericType) + validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType) } func UnmarshalModelProviderConfiguration(data []byte) (*ModelProviderConfiguration, error) { @@ -289,7 +286,7 @@ func UnmarshalModelProviderConfiguration(data []byte) (*ModelProviderConfigurati return nil, err } - err = global_model_provider_validator.Struct(modelProviderConfiguration) + err = validators.GlobalEntitiesValidator.Struct(modelProviderConfiguration) if err != nil { return nil, err } diff --git a/internal/types/entities/plugin_entities/plugin_declaration.go b/internal/types/entities/plugin_entities/plugin_declaration.go index 8185174..aac4108 100644 --- a/internal/types/entities/plugin_entities/plugin_declaration.go +++ b/internal/types/entities/plugin_entities/plugin_declaration.go @@ -67,14 +67,20 @@ type PluginDeclarationMeta struct { Arch []string `json:"arch" yaml:"arch" validate:"required,dive,plugin_declaration_platform_arch"` } +type PluginDeclarationExecution struct { + Install string `json:"install" yaml:"install" validate:"omitempty"` + Launch string `json:"launch" yaml:"launch" validate:"omitempty"` +} + type PluginDeclaration struct { - Version string `json:"version" yaml:"version" validate:"required"` - Type DifyManifestType `json:"type" yaml:"type" validate:"required,eq=plugin"` - Author string `json:"author" yaml:"author" validate:"required"` - Name string `json:"name" yaml:"name" validate:"required" enum:"plugin"` - CreatedAt time.Time `json:"created_at" yaml:"created_at" validate:"required"` - Resource PluginResourceRequirement `json:"resource" yaml:"resource" validate:"required"` - Plugins []string `json:"plugins" yaml:"plugins" validate:"required"` + Version string `json:"version" yaml:"version" validate:"required"` + Type DifyManifestType `json:"type" yaml:"type" validate:"required,eq=plugin"` + Author string `json:"author" yaml:"author" validate:"required"` + Name string `json:"name" yaml:"name" validate:"required" enum:"plugin"` + CreatedAt time.Time `json:"created_at" yaml:"created_at" validate:"required"` + Resource PluginResourceRequirement `json:"resource" yaml:"resource" validate:"required"` + Plugins []string `json:"plugins" yaml:"plugins" validate:"required"` + Execution PluginDeclarationExecution `json:"execution" yaml:"execution" validate:"required"` } func (p *PluginDeclaration) Identity() string { diff --git a/internal/types/entities/plugin_entities/request.go b/internal/types/entities/plugin_entities/request.go index 01d0df7..108b606 100644 --- a/internal/types/entities/plugin_entities/request.go +++ b/internal/types/entities/plugin_entities/request.go @@ -1,24 +1,9 @@ package plugin_entities -type InvokePluginRequestData interface { - InvokeToolRequest | InvokeModelRequest -} - -type InvokeModelRequest struct { -} - -type InvokePluginRequest[T InvokePluginRequestData] struct { +type InvokePluginRequest[T any] struct { PluginName string `json:"plugin_name" binding:"required"` PluginVersion string `json:"plugin_version" binding:"required"` TenantId string `json:"tenant_id" binding:"required"` UserId string `json:"user_id" binding:"required"` Data T `json:"data" binding:"required"` } - -type InvokeToolRequest struct { - ProviderName string `json:"provider_name" binding:"required"` - ToolName string `json:"tool_name" binding:"required"` - ToolRuntime struct { - } `json:"tool_runtime" binding:"required"` - Parameters map[string]interface{} `json:"parameters" binding:"required"` -} diff --git a/internal/types/entities/plugin_entities/tool_configuration.go b/internal/types/entities/plugin_entities/tool_configuration.go index 4fcc2f0..4ae549b 100644 --- a/internal/types/entities/plugin_entities/tool_configuration.go +++ b/internal/types/entities/plugin_entities/tool_configuration.go @@ -1,12 +1,14 @@ package plugin_entities import ( + "encoding/json" "fmt" "github.com/go-playground/locales/en" ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" en_translations "github.com/go-playground/validator/v10/translations/en" + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" ) @@ -188,20 +190,16 @@ type ToolProviderConfiguration struct { Tools []ToolConfiguration `json:"tools" validate:"required,dive"` } -var ( - global_tool_provider_validator = validator.New() -) - func init() { // init validator en := en.New() uni := ut.New(en, en) translator, _ := uni.GetTranslator("en") // register translations for default validators - en_translations.RegisterDefaultTranslations(global_tool_provider_validator, translator) + en_translations.RegisterDefaultTranslations(validators.GlobalEntitiesValidator, translator) - global_tool_provider_validator.RegisterValidation("tool_parameter_type", isToolParameterType) - global_tool_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("tool_parameter_type", isToolParameterType) + validators.GlobalEntitiesValidator.RegisterTranslation( "tool_parameter_type", translator, func(ut ut.Translator) error { @@ -213,8 +211,8 @@ func init() { }, ) - global_tool_provider_validator.RegisterValidation("tool_parameter_form", isToolParameterForm) - global_tool_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("tool_parameter_form", isToolParameterForm) + validators.GlobalEntitiesValidator.RegisterTranslation( "tool_parameter_form", translator, func(ut ut.Translator) error { @@ -226,8 +224,8 @@ func init() { }, ) - global_tool_provider_validator.RegisterValidation("credential_type", isCredentialType) - global_tool_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("credential_type", isCredentialType) + validators.GlobalEntitiesValidator.RegisterTranslation( "credential_type", translator, func(ut ut.Translator) error { @@ -239,8 +237,8 @@ func init() { }, ) - global_tool_provider_validator.RegisterValidation("tool_label", isToolLabel) - global_tool_provider_validator.RegisterTranslation( + validators.GlobalEntitiesValidator.RegisterValidation("tool_label", isToolLabel) + validators.GlobalEntitiesValidator.RegisterTranslation( "tool_label", translator, func(ut ut.Translator) error { @@ -252,7 +250,25 @@ func init() { }, ) - global_tool_provider_validator.RegisterValidation("is_basic_type", isGenericType) + validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType) +} + +func (t *ToolProviderConfiguration) UnmarshalJSON(data []byte) error { + type Alias ToolProviderConfiguration + aux := &struct { + *Alias + }{ + Alias: (*Alias)(t), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if err := validators.GlobalEntitiesValidator.Struct(t); err != nil { + return err + } + + return nil } func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration, error) { @@ -261,7 +277,7 @@ func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration return nil, fmt.Errorf("failed to unmarshal tool provider configuration: %w", err) } - if err := global_tool_provider_validator.Struct(obj); err != nil { + if err := validators.GlobalEntitiesValidator.Struct(obj); err != nil { return nil, fmt.Errorf("failed to validate tool provider configuration: %w", err) } diff --git a/internal/types/entities/requests/model.go b/internal/types/entities/requests/model.go new file mode 100644 index 0000000..6643501 --- /dev/null +++ b/internal/types/entities/requests/model.go @@ -0,0 +1,38 @@ +package requests + +import ( + "encoding/json" + + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" +) + +type RequestInvokeLLM struct { + Provider string `json:"provider"` + ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"` + Model string `json:"model"` + ModelParameters map[string]any `json:"model_parameters" validate:"omitempty,dive,is_basic_type"` + PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty,dive"` + Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"` + Stop []string `json:"stop" validate:"omitempty"` + Stream bool `json:"stream"` + Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` +} + +func (r *RequestInvokeLLM) UnmarshalJSON(data []byte) error { + type Alias RequestInvokeLLM + aux := &struct { + *Alias + }{ + Alias: (*Alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if err := validators.GlobalEntitiesValidator.Struct(r); err != nil { + return err + } + + return nil +} diff --git a/internal/types/entities/requests/tool.go b/internal/types/entities/requests/tool.go new file mode 100644 index 0000000..43d47d5 --- /dev/null +++ b/internal/types/entities/requests/tool.go @@ -0,0 +1,32 @@ +package requests + +import ( + "encoding/json" + + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" +) + +type RequestInvokeTool struct { + Provider string `json:"provider" validate:"required"` + Tool string `json:"tool" validate:"required"` + ToolParameters map[string]any `json:"tool_parameters" validate:"omitempty,dive,is_basic_type"` + Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` +} + +func (r *RequestInvokeTool) UnmarshalJSON(data []byte) error { + type Alias RequestInvokeTool + aux := &struct { + *Alias + }{ + Alias: (*Alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + if err := validators.GlobalEntitiesValidator.Struct(r); err != nil { + return err + } + + return nil +} diff --git a/internal/types/entities/runtime.go b/internal/types/entities/runtime.go index c18b1a5..70cbfc9 100644 --- a/internal/types/entities/runtime.go +++ b/internal/types/entities/runtime.go @@ -25,6 +25,7 @@ type ( Stop() Configuration() *plugin_entities.PluginDeclaration RuntimeState() *PluginRuntimeState + Wait() (<-chan bool, error) } PluginRuntimeSessionIOInterface interface { diff --git a/internal/types/validators/validators.go b/internal/types/validators/validators.go new file mode 100644 index 0000000..bd22a4c --- /dev/null +++ b/internal/types/validators/validators.go @@ -0,0 +1,7 @@ +package validators + +import "github.com/go-playground/validator/v10" + +var ( + GlobalEntitiesValidator = validator.New() +) diff --git a/internal/utils/cache/redis.go b/internal/utils/cache/redis.go index 9f724ad..03358df 100644 --- a/internal/utils/cache/redis.go +++ b/internal/utils/cache/redis.go @@ -151,7 +151,7 @@ func GetMap[V any](key string) (map[string]V, error) { for k, v := range val { value, err := parser.UnmarshalJson[V](v) if err != nil { - return nil, err + continue } result[k] = value diff --git a/internal/utils/requests/http_options.go b/internal/utils/http_requests/http_options.go similarity index 98% rename from internal/utils/requests/http_options.go rename to internal/utils/http_requests/http_options.go index 7240136..251f8db 100644 --- a/internal/utils/requests/http_options.go +++ b/internal/utils/http_requests/http_options.go @@ -1,4 +1,4 @@ -package requests +package http_requests type HttpOptions struct { Type string diff --git a/internal/utils/requests/http_request.go b/internal/utils/http_requests/http_request.go similarity index 98% rename from internal/utils/requests/http_request.go rename to internal/utils/http_requests/http_request.go index 5e08c5a..d46ce54 100644 --- a/internal/utils/requests/http_request.go +++ b/internal/utils/http_requests/http_request.go @@ -1,4 +1,4 @@ -package requests +package http_requests import ( "bytes" diff --git a/internal/utils/requests/http_warpper.go b/internal/utils/http_requests/http_warpper.go similarity index 99% rename from internal/utils/requests/http_warpper.go rename to internal/utils/http_requests/http_warpper.go index 68adeb7..de7da8e 100644 --- a/internal/utils/requests/http_warpper.go +++ b/internal/utils/http_requests/http_warpper.go @@ -1,4 +1,4 @@ -package requests +package http_requests import ( "bufio" diff --git a/internal/utils/stream/response.go b/internal/utils/stream/response.go index 62eadc5..f6f8f4e 100644 --- a/internal/utils/stream/response.go +++ b/internal/utils/stream/response.go @@ -15,6 +15,7 @@ type StreamResponse[T any] struct { max int listening bool onClose func() + err error } func NewStreamResponse[T any](max int) *StreamResponse[T] { @@ -31,12 +32,12 @@ func (r *StreamResponse[T]) OnClose(f func()) { func (r *StreamResponse[T]) Next() bool { r.l.Lock() - if r.closed { + if r.closed && r.q.Len() == 0 && r.err == nil { r.l.Unlock() return false } - if r.q.Len() > 0 { + if r.q.Len() > 0 || r.err != nil { r.l.Unlock() return true } @@ -59,7 +60,13 @@ func (r *StreamResponse[T]) Read() (T, error) { return data, nil } else { var data T - return data, errors.New("no data available, please call Next() to wait for data") + if r.err != nil { + err := r.err + r.err = nil + return data, err + } + + return data, errors.New("no data available") } } @@ -117,3 +124,16 @@ func (r *StreamResponse[T]) Size() int { return r.q.Len() } + +func (r *StreamResponse[T]) WriteError(err error) { + r.l.Lock() + defer r.l.Unlock() + + r.err = err + + if r.q.Len() == 0 { + if r.listening { + r.sig <- true + } + } +} diff --git a/internal/utils/stream/response_test.go b/internal/utils/stream/response_test.go new file mode 100644 index 0000000..68cf44b --- /dev/null +++ b/internal/utils/stream/response_test.go @@ -0,0 +1,55 @@ +package stream + +import ( + "errors" + "testing" + "time" +) + +func TestStreamGenerator(t *testing.T) { + response := NewStreamResponse[int](512) + + go func() { + for i := 0; i < 10000; i++ { + response.Write(i) + time.Sleep(time.Microsecond) + } + response.Close() + }() + + msg := 0 + + for response.Next() { + _, err := response.Read() + if err != nil { + t.Error(err) + } + msg += 1 + } + + if msg != 10000 { + t.Errorf("Expected 10000 messages, got %d", msg) + } +} + +func TestStreamGeneratorErrorMessage(t *testing.T) { + response := NewStreamResponse[int](512) + + go func() { + for i := 0; i < 10000; i++ { + response.Write(i) + time.Sleep(time.Microsecond) + } + response.WriteError(errors.New("test error")) + response.Close() + }() + + for response.Next() { + _, err := response.Read() + if err != nil { + if err.Error() != "test error" { + t.Error(err) + } + } + } +}