From 184fb1efa0a264afa48fa95cc3519ffd81883bdf Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 19 Jul 2024 22:53:59 +0800 Subject: [PATCH] feat: text-embedding and moderation --- internal/core/plugin_daemon/model_service.go | 12 ++-- internal/server/controller.go | 55 +++++++++++++++++++ internal/server/http.go | 5 ++ internal/service/invoke.go | 50 +++++++++++++++++ .../entities/model_entities/moderation.go | 4 ++ .../entities/model_entities/speech2text.go | 4 ++ internal/types/entities/model_entities/tts.go | 4 ++ internal/types/entities/requests/model.go | 14 ++--- 8 files changed, 135 insertions(+), 13 deletions(-) diff --git a/internal/core/plugin_daemon/model_service.go b/internal/core/plugin_daemon/model_service.go index b575ac8..c833842 100644 --- a/internal/core/plugin_daemon/model_service.go +++ b/internal/core/plugin_daemon/model_service.go @@ -143,9 +143,9 @@ func InvokeTTS( session *session_manager.Session, request *requests.RequestInvokeTTS, ) ( - *stream.StreamResponse[string], error, + *stream.StreamResponse[model_entities.TTSResult], error, ) { - return genericInvokePlugin[requests.RequestInvokeTTS, string]( + return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult]( session, request, 1, @@ -158,9 +158,9 @@ func InvokeSpeech2Text( session *session_manager.Session, request *requests.RequestInvokeSpeech2Text, ) ( - *stream.StreamResponse[string], error, + *stream.StreamResponse[model_entities.Speech2TextResult], error, ) { - return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string]( + return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult]( session, request, 1, @@ -173,9 +173,9 @@ func InvokeModeration( session *session_manager.Session, request *requests.RequestInvokeModeration, ) ( - *stream.StreamResponse[bool], error, + *stream.StreamResponse[model_entities.ModerationResult], error, ) { - return genericInvokePlugin[requests.RequestInvokeModeration, bool]( + return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult]( session, request, 1, diff --git a/internal/server/controller.go b/internal/server/controller.go index 57749f8..f3ca30d 100644 --- a/internal/server/controller.go +++ b/internal/server/controller.go @@ -32,3 +32,58 @@ func InvokeLLM(c *gin.Context) { }, ) } + +func InvokeTextEmbedding(c *gin.Context) { + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding] + + BindRequest[request]( + c, + func(itr request) { + service.InvokeTextEmbedding(&itr, c) + }, + ) +} + +func InvokeRerank(c *gin.Context) { + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank] + + BindRequest[request]( + c, + func(itr request) { + service.InvokeRerank(&itr, c) + }, + ) +} + +func InvokeTTS(c *gin.Context) { + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS] + + BindRequest[request]( + c, + func(itr request) { + service.InvokeTTS(&itr, c) + }, + ) +} + +func InvokeSpeech2Text(c *gin.Context) { + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text] + + BindRequest[request]( + c, + func(itr request) { + service.InvokeSpeech2Text(&itr, c) + }, + ) +} + +func InvokeModeration(c *gin.Context) { + type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration] + + BindRequest[request]( + c, + func(itr request) { + service.InvokeModeration(&itr, c) + }, + ) +} diff --git a/internal/server/http.go b/internal/server/http.go index 458ae10..f1669ad 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -13,6 +13,11 @@ func server(config *app.Config) { engine.GET("/health/check", HealthCheck) engine.POST("/plugin/tool/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTool) engine.POST("/plugin/llm/invoke", CheckingKey(config.PluginInnerApiKey), InvokeLLM) + engine.POST("/plugin/text_embedding/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTextEmbedding) + engine.POST("/plugin/rerank/invoke", CheckingKey(config.PluginInnerApiKey), InvokeRerank) + engine.POST("/plugin/tts/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTTS) + engine.POST("/plugin/speech2text/invoke", CheckingKey(config.PluginInnerApiKey), InvokeSpeech2Text) + engine.POST("/plugin/moderation/invoke", CheckingKey(config.PluginInnerApiKey), InvokeModeration) engine.Run(fmt.Sprintf(":%d", config.SERVER_PORT)) } diff --git a/internal/service/invoke.go b/internal/service/invoke.go index 0a31dc2..6497d37 100644 --- a/internal/service/invoke.go +++ b/internal/service/invoke.go @@ -30,3 +30,53 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM] return plugin_daemon.InvokeLLM(session, &r.Data) }, ctx) } + +func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding], ctx *gin.Context) { + // create session + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + defer session.Close() + + baseSSEService(r, func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) { + return plugin_daemon.InvokeTextEmbedding(session, &r.Data) + }, ctx) +} + +func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank], ctx *gin.Context) { + // create session + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + defer session.Close() + + baseSSEService(r, func() (*stream.StreamResponse[model_entities.RerankResult], error) { + return plugin_daemon.InvokeRerank(session, &r.Data) + }, ctx) +} + +func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS], ctx *gin.Context) { + // create session + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + defer session.Close() + + baseSSEService(r, func() (*stream.StreamResponse[model_entities.TTSResult], error) { + return plugin_daemon.InvokeTTS(session, &r.Data) + }, ctx) +} + +func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text], ctx *gin.Context) { + // create session + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + defer session.Close() + + baseSSEService(r, func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) { + return plugin_daemon.InvokeSpeech2Text(session, &r.Data) + }, ctx) +} + +func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration], ctx *gin.Context) { + // create session + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + defer session.Close() + + baseSSEService(r, func() (*stream.StreamResponse[model_entities.ModerationResult], error) { + return plugin_daemon.InvokeModeration(session, &r.Data) + }, ctx) +} diff --git a/internal/types/entities/model_entities/moderation.go b/internal/types/entities/model_entities/moderation.go index 71ec591..8c10428 100644 --- a/internal/types/entities/model_entities/moderation.go +++ b/internal/types/entities/model_entities/moderation.go @@ -1 +1,5 @@ package model_entities + +type ModerationResult struct { + Result bool `json:"result"` +} diff --git a/internal/types/entities/model_entities/speech2text.go b/internal/types/entities/model_entities/speech2text.go index 71ec591..4747307 100644 --- a/internal/types/entities/model_entities/speech2text.go +++ b/internal/types/entities/model_entities/speech2text.go @@ -1 +1,5 @@ package model_entities + +type Speech2TextResult struct { + Result string `json:"result"` +} diff --git a/internal/types/entities/model_entities/tts.go b/internal/types/entities/model_entities/tts.go index 71ec591..979b51c 100644 --- a/internal/types/entities/model_entities/tts.go +++ b/internal/types/entities/model_entities/tts.go @@ -1 +1,5 @@ package model_entities + +type TTSResult struct { + Result string `json:"result"` +} diff --git a/internal/types/entities/requests/model.go b/internal/types/entities/requests/model.go index fae95e6..045587f 100644 --- a/internal/types/entities/requests/model.go +++ b/internal/types/entities/requests/model.go @@ -6,7 +6,7 @@ import ( type BaseRequestInvokeModel struct { Provider string `json:"provider" validate:"required"` - ModelType model_entities.ModelType `json:"model_type" mapstructure:"model_type" validate:"required,model_type"` + ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"` Model string `json:"model" validate:"required"` Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` } @@ -23,11 +23,11 @@ func (r *BaseRequestInvokeModel) ToCallerArguments() map[string]any { type RequestInvokeLLM struct { BaseRequestInvokeModel - ModelParameters map[string]any `json:"model_parameters" mapstructure:"model_parameters" validate:"omitempty,dive,is_basic_type"` - PromptMessages []model_entities.PromptMessage `json:"prompt_messages" mapstructure:"prompt_messages" validate:"omitempty,dive"` + ModelParameters map[string]any `json:"model_parameters" validate:"omitempty,dive,is_basic_type"` + PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty,dive"` Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"` Stop []string `json:"stop" validate:"omitempty"` - Stream bool `json:"stream" mapstructure:"stream"` + Stream bool `json:"stream" ` } type RequestInvokeTextEmbedding struct { @@ -41,14 +41,14 @@ type RequestInvokeRerank struct { Query string `json:"query" validate:"required"` Docs []string `json:"docs" validate:"required,dive"` - ScoreThreshold float64 `json:"score_threshold" mapstructure:"score_threshold"` - TopN int `json:"top_n" mapstructure:"top_n"` + ScoreThreshold float64 `json:"score_threshold" ` + TopN int `json:"top_n" ` } type RequestInvokeTTS struct { BaseRequestInvokeModel - ContentText string `json:"content_text" mapstructure:"content_text" validate:"required"` + ContentText string `json:"content_text" validate:"required"` Voice string `json:"voice" validate:"required"` }