Skip to content

Commit

Permalink
feat: text-embedding and moderation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Jul 19, 2024
1 parent 5b96e61 commit 184fb1e
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 13 deletions.
12 changes: 6 additions & 6 deletions internal/core/plugin_daemon/model_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ func InvokeTTS(
session *session_manager.Session,
request *requests.RequestInvokeTTS,
) (
*stream.StreamResponse[string], error,
*stream.StreamResponse[model_entities.TTSResult], error,
) {
return genericInvokePlugin[requests.RequestInvokeTTS, string](
return genericInvokePlugin[requests.RequestInvokeTTS, model_entities.TTSResult](
session,
request,
1,
Expand All @@ -158,9 +158,9 @@ func InvokeSpeech2Text(
session *session_manager.Session,
request *requests.RequestInvokeSpeech2Text,
) (
*stream.StreamResponse[string], error,
*stream.StreamResponse[model_entities.Speech2TextResult], error,
) {
return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string](
return genericInvokePlugin[requests.RequestInvokeSpeech2Text, model_entities.Speech2TextResult](
session,
request,
1,
Expand All @@ -173,9 +173,9 @@ func InvokeModeration(
session *session_manager.Session,
request *requests.RequestInvokeModeration,
) (
*stream.StreamResponse[bool], error,
*stream.StreamResponse[model_entities.ModerationResult], error,
) {
return genericInvokePlugin[requests.RequestInvokeModeration, bool](
return genericInvokePlugin[requests.RequestInvokeModeration, model_entities.ModerationResult](
session,
request,
1,
Expand Down
55 changes: 55 additions & 0 deletions internal/server/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,58 @@ func InvokeLLM(c *gin.Context) {
},
)
}

func InvokeTextEmbedding(c *gin.Context) {
type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding]

BindRequest[request](
c,
func(itr request) {
service.InvokeTextEmbedding(&itr, c)
},
)
}

func InvokeRerank(c *gin.Context) {
type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank]

BindRequest[request](
c,
func(itr request) {
service.InvokeRerank(&itr, c)
},
)
}

func InvokeTTS(c *gin.Context) {
type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS]

BindRequest[request](
c,
func(itr request) {
service.InvokeTTS(&itr, c)
},
)
}

func InvokeSpeech2Text(c *gin.Context) {
type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text]

BindRequest[request](
c,
func(itr request) {
service.InvokeSpeech2Text(&itr, c)
},
)
}

func InvokeModeration(c *gin.Context) {
type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration]

BindRequest[request](
c,
func(itr request) {
service.InvokeModeration(&itr, c)
},
)
}
5 changes: 5 additions & 0 deletions internal/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ func server(config *app.Config) {
engine.GET("/health/check", HealthCheck)
engine.POST("/plugin/tool/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTool)
engine.POST("/plugin/llm/invoke", CheckingKey(config.PluginInnerApiKey), InvokeLLM)
engine.POST("/plugin/text_embedding/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTextEmbedding)
engine.POST("/plugin/rerank/invoke", CheckingKey(config.PluginInnerApiKey), InvokeRerank)
engine.POST("/plugin/tts/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTTS)
engine.POST("/plugin/speech2text/invoke", CheckingKey(config.PluginInnerApiKey), InvokeSpeech2Text)
engine.POST("/plugin/moderation/invoke", CheckingKey(config.PluginInnerApiKey), InvokeModeration)

engine.Run(fmt.Sprintf(":%d", config.SERVER_PORT))
}
50 changes: 50 additions & 0 deletions internal/service/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,53 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
return plugin_daemon.InvokeLLM(session, &r.Data)
}, ctx)
}

func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding], ctx *gin.Context) {
// create session
session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
defer session.Close()

baseSSEService(r, func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) {
return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
}, ctx)
}

func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank], ctx *gin.Context) {
// create session
session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
defer session.Close()

baseSSEService(r, func() (*stream.StreamResponse[model_entities.RerankResult], error) {
return plugin_daemon.InvokeRerank(session, &r.Data)
}, ctx)
}

func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS], ctx *gin.Context) {
// create session
session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
defer session.Close()

baseSSEService(r, func() (*stream.StreamResponse[model_entities.TTSResult], error) {
return plugin_daemon.InvokeTTS(session, &r.Data)
}, ctx)
}

func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text], ctx *gin.Context) {
// create session
session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
defer session.Close()

baseSSEService(r, func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) {
return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
}, ctx)
}

func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration], ctx *gin.Context) {
// create session
session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
defer session.Close()

baseSSEService(r, func() (*stream.StreamResponse[model_entities.ModerationResult], error) {
return plugin_daemon.InvokeModeration(session, &r.Data)
}, ctx)
}
4 changes: 4 additions & 0 deletions internal/types/entities/model_entities/moderation.go
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
package model_entities

type ModerationResult struct {
Result bool `json:"result"`
}
4 changes: 4 additions & 0 deletions internal/types/entities/model_entities/speech2text.go
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
package model_entities

type Speech2TextResult struct {
Result string `json:"result"`
}
4 changes: 4 additions & 0 deletions internal/types/entities/model_entities/tts.go
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
package model_entities

type TTSResult struct {
Result string `json:"result"`
}
14 changes: 7 additions & 7 deletions internal/types/entities/requests/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

type BaseRequestInvokeModel struct {
Provider string `json:"provider" validate:"required"`
ModelType model_entities.ModelType `json:"model_type" mapstructure:"model_type" validate:"required,model_type"`
ModelType model_entities.ModelType `json:"model_type" validate:"required,model_type"`
Model string `json:"model" validate:"required"`
Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
}
Expand All @@ -23,11 +23,11 @@ func (r *BaseRequestInvokeModel) ToCallerArguments() map[string]any {
type RequestInvokeLLM struct {
BaseRequestInvokeModel

ModelParameters map[string]any `json:"model_parameters" mapstructure:"model_parameters" validate:"omitempty,dive,is_basic_type"`
PromptMessages []model_entities.PromptMessage `json:"prompt_messages" mapstructure:"prompt_messages" validate:"omitempty,dive"`
ModelParameters map[string]any `json:"model_parameters" validate:"omitempty,dive,is_basic_type"`
PromptMessages []model_entities.PromptMessage `json:"prompt_messages" validate:"omitempty,dive"`
Tools []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
Stop []string `json:"stop" validate:"omitempty"`
Stream bool `json:"stream" mapstructure:"stream"`
Stream bool `json:"stream" `
}

type RequestInvokeTextEmbedding struct {
Expand All @@ -41,14 +41,14 @@ type RequestInvokeRerank struct {

Query string `json:"query" validate:"required"`
Docs []string `json:"docs" validate:"required,dive"`
ScoreThreshold float64 `json:"score_threshold" mapstructure:"score_threshold"`
TopN int `json:"top_n" mapstructure:"top_n"`
ScoreThreshold float64 `json:"score_threshold" `
TopN int `json:"top_n" `
}

type RequestInvokeTTS struct {
BaseRequestInvokeModel

ContentText string `json:"content_text" mapstructure:"content_text" validate:"required"`
ContentText string `json:"content_text" validate:"required"`
Voice string `json:"voice" validate:"required"`
}

Expand Down

0 comments on commit 184fb1e

Please sign in to comment.