From cbd3189e06e99abff8216f94d426863741624d15 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 19 Jul 2024 21:03:26 +0800 Subject: [PATCH] refactor: validator --- .../dify_invocation/workflow_node_data.go | 9 +-- internal/core/plugin_daemon/invoke_dify.go | 9 --- internal/types/entities/model_entities/llm.go | 59 ------------------- .../types/entities/model_entities/llm_test.go | 43 +++++--------- .../entities/model_entities/moderation.go | 1 + .../types/entities/model_entities/rerank.go | 12 ++++ .../entities/model_entities/speech2text.go | 1 + .../entities/model_entities/text_embedding.go | 1 + internal/types/entities/model_entities/tts.go | 1 + .../plugin_entities/model_configuration.go | 17 ------ .../model_configuration_test.go | 3 +- .../plugin_entities/tool_configuration.go | 19 ------ internal/types/entities/requests/model.go | 21 ------- internal/types/entities/requests/tool.go | 24 -------- internal/utils/parser/json.go | 14 ++++- internal/utils/parser/yaml.go | 16 ++++- 16 files changed, 59 insertions(+), 191 deletions(-) create mode 100644 internal/types/entities/model_entities/moderation.go create mode 100644 internal/types/entities/model_entities/rerank.go create mode 100644 internal/types/entities/model_entities/speech2text.go create mode 100644 internal/types/entities/model_entities/text_embedding.go create mode 100644 internal/types/entities/model_entities/tts.go diff --git a/internal/core/dify_invocation/workflow_node_data.go b/internal/core/dify_invocation/workflow_node_data.go index 884bd3f..95f514d 100644 --- a/internal/core/dify_invocation/workflow_node_data.go +++ b/internal/core/dify_invocation/workflow_node_data.go @@ -3,7 +3,7 @@ package dify_invocation type WorkflowNodeData interface { FromMap(map[string]any) error - *KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData | *CodeNodeData + *KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData } type NodeType string @@ -35,10 +35,3 @@ type ParameterExtractorNodeData struct { func (r *ParameterExtractorNodeData) FromMap(data map[string]any) error { return nil } - -type CodeNodeData struct { -} - -func (r *CodeNodeData) FromMap(data map[string]any) error { - return nil -} diff --git a/internal/core/plugin_daemon/invoke_dify.go b/internal/core/plugin_daemon/invoke_dify.go index 81ab586..eb3e9f0 100644 --- a/internal/core/plugin_daemon/invoke_dify.go +++ b/internal/core/plugin_daemon/invoke_dify.go @@ -89,15 +89,6 @@ func invokeDify( return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error()) } submitNodeInvocationRequestTask(runtime, session, request_id, &d) - case dify_invocation.CODE: - d := dify_invocation.InvokeNodeRequest[*dify_invocation.CodeNodeData]{ - NodeType: dify_invocation.CODE, - NodeData: &dify_invocation.CodeNodeData{}, - } - if err := d.FromMap(node_data); err != nil { - return fmt.Errorf("unmarshal code node data failed: %s", err.Error()) - } - submitNodeInvocationRequestTask(runtime, session, request_id, &d) default: return fmt.Errorf("unknown node type: %s", node_type) } diff --git a/internal/types/entities/model_entities/llm.go b/internal/types/entities/model_entities/llm.go index 13ad677..8e40975 100644 --- a/internal/types/entities/model_entities/llm.go +++ b/internal/types/entities/model_entities/llm.go @@ -156,11 +156,6 @@ func (p *PromptMessage) UnmarshalJSON(data []byte) error { } } - // Validate the struct - if err := validators.GlobalEntitiesValidator.Struct(p); err != nil { - return err - } - // validate tool call id if p.Role == PROMPT_MESSAGE_ROLE_TOOL && p.ToolCallId == "" { return errors.New("tool call id is required") @@ -175,24 +170,6 @@ type PromptMessageTool struct { Parameters map[string]any `json:"parameters"` } -func (p *PromptMessageTool) UnmarshalJSON(data []byte) error { - type Alias PromptMessageTool - aux := &struct { - *Alias - }{ - Alias: (*Alias)(p), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - if err := validators.GlobalEntitiesValidator.Struct(p); err != nil { - return err - } - - return nil -} - type LLMResultChunk struct { Model LLMModel `json:"model" validate:"required"` PromptMessages []PromptMessage `json:"prompt_messages" validate:"required,dive"` @@ -200,24 +177,6 @@ type LLMResultChunk struct { Delta LLMResultChunkDelta `json:"delta" validate:"required"` } -func (l *LLMResultChunk) UnmarshalJSON(data []byte) error { - type Alias LLMResultChunk - aux := &struct { - *Alias - }{ - Alias: (*Alias)(l), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - if err := validators.GlobalEntitiesValidator.Struct(l); err != nil { - return err - } - - return nil -} - type LLMUsage struct { PromptTokens *int `json:"prompt_tokens" validate:"required"` PromptUnitPrice decimal.Decimal `json:"prompt_unit_price" validate:"required"` @@ -233,24 +192,6 @@ type LLMUsage struct { Latency *float64 `json:"latency" validate:"required"` } -func (l *LLMUsage) UnmarshalJSON(data []byte) error { - type Alias LLMUsage - aux := &struct { - *Alias - }{ - Alias: (*Alias)(l), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - if err := validators.GlobalEntitiesValidator.Struct(l); err != nil { - return err - } - - return nil -} - type LLMResultChunkDelta struct { Index *int `json:"index" validate:"required"` Message PromptMessage `json:"message" validate:"required"` diff --git a/internal/types/entities/model_entities/llm_test.go b/internal/types/entities/model_entities/llm_test.go index 78fec14..a2206c4 100644 --- a/internal/types/entities/model_entities/llm_test.go +++ b/internal/types/entities/model_entities/llm_test.go @@ -1,8 +1,9 @@ package model_entities import ( - "encoding/json" "testing" + + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" ) func TestFullFunctionPromptMessage(t *testing.T) { @@ -42,9 +43,7 @@ func TestFullFunctionPromptMessage(t *testing.T) { ` ) - var prompt_message PromptMessage - - err := json.Unmarshal([]byte(system_message), &prompt_message) + prompt_message, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(system_message)) if err != nil { t.Error(err) } @@ -52,7 +51,7 @@ func TestFullFunctionPromptMessage(t *testing.T) { t.Error("role is not system") } - err = json.Unmarshal([]byte(user_message), &prompt_message) + prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(user_message)) if err != nil { t.Error(err) } @@ -60,7 +59,7 @@ func TestFullFunctionPromptMessage(t *testing.T) { t.Error("role is not user") } - err = json.Unmarshal([]byte(assistant_message), &prompt_message) + prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(assistant_message)) if err != nil { t.Error(err) } @@ -68,7 +67,7 @@ func TestFullFunctionPromptMessage(t *testing.T) { t.Error("role is not assistant") } - err = json.Unmarshal([]byte(image_message), &prompt_message) + prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(image_message)) if err != nil { t.Error(err) } @@ -79,7 +78,7 @@ func TestFullFunctionPromptMessage(t *testing.T) { t.Error("type is not image") } - err = json.Unmarshal([]byte(tool_message), &prompt_message) + prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(tool_message)) if err != nil { t.Error(err) } @@ -101,9 +100,7 @@ func TestWrongRole(t *testing.T) { ` ) - var prompt_message PromptMessage - - err := json.Unmarshal([]byte(wrong_role), &prompt_message) + _, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_role)) if err == nil { t.Error("error is nil") } @@ -119,9 +116,7 @@ func TestWrongContent(t *testing.T) { ` ) - var prompt_message PromptMessage - - err := json.Unmarshal([]byte(wrong_content), &prompt_message) + _, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content)) if err == nil { t.Error("error is nil") } @@ -142,9 +137,7 @@ func TestWrongContentArray(t *testing.T) { ` ) - var prompt_message PromptMessage - - err := json.Unmarshal([]byte(wrong_content_array), &prompt_message) + _, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array)) if err == nil { t.Error("error is nil") } @@ -164,9 +157,7 @@ func TestWrongContentArray2(t *testing.T) { ` ) - var prompt_message PromptMessage - - err := json.Unmarshal([]byte(wrong_content_array2), &prompt_message) + _, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array2)) if err == nil { t.Error("error is nil") } @@ -191,9 +182,7 @@ func TestWrongContentArray3(t *testing.T) { ` ) - var prompt_message PromptMessage - - err := json.Unmarshal([]byte(wrong_content_array3), &prompt_message) + _, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array3)) if err == nil { t.Error("error is nil") } @@ -241,9 +230,7 @@ func TestFullFunctionLLMResultChunk(t *testing.T) { ` ) - var c LLMResultChunk - - err := json.Unmarshal([]byte(llm_result_chunk), &c) + _, err := parser.UnmarshalJsonBytes[LLMResultChunk]([]byte(llm_result_chunk)) if err != nil { t.Error(err) } @@ -269,9 +256,7 @@ func TestZeroLLMUsage(t *testing.T) { ` ) - var u LLMUsage - - err := json.Unmarshal([]byte(llm_usage), &u) + _, err := parser.UnmarshalJsonBytes[LLMUsage]([]byte(llm_usage)) if err != nil { t.Error(err) } diff --git a/internal/types/entities/model_entities/moderation.go b/internal/types/entities/model_entities/moderation.go new file mode 100644 index 0000000..71ec591 --- /dev/null +++ b/internal/types/entities/model_entities/moderation.go @@ -0,0 +1 @@ +package model_entities diff --git a/internal/types/entities/model_entities/rerank.go b/internal/types/entities/model_entities/rerank.go new file mode 100644 index 0000000..955d5a8 --- /dev/null +++ b/internal/types/entities/model_entities/rerank.go @@ -0,0 +1,12 @@ +package model_entities + +type RerankDocument struct { + Index *int `json:"index" validate:"required"` + Text *string `json:"text" validate:"required"` + Score *float64 `json:"score" validate:"required"` +} + +type RerankResult struct { + Model string `json:"model" validate:"required"` + Docs []RerankDocument `json:"docs" validate:"required,dive"` +} diff --git a/internal/types/entities/model_entities/speech2text.go b/internal/types/entities/model_entities/speech2text.go new file mode 100644 index 0000000..71ec591 --- /dev/null +++ b/internal/types/entities/model_entities/speech2text.go @@ -0,0 +1 @@ +package model_entities diff --git a/internal/types/entities/model_entities/text_embedding.go b/internal/types/entities/model_entities/text_embedding.go new file mode 100644 index 0000000..71ec591 --- /dev/null +++ b/internal/types/entities/model_entities/text_embedding.go @@ -0,0 +1 @@ +package model_entities diff --git a/internal/types/entities/model_entities/tts.go b/internal/types/entities/model_entities/tts.go new file mode 100644 index 0000000..71ec591 --- /dev/null +++ b/internal/types/entities/model_entities/tts.go @@ -0,0 +1 @@ +package model_entities diff --git a/internal/types/entities/plugin_entities/model_configuration.go b/internal/types/entities/plugin_entities/model_configuration.go index f4179ce..93dfc64 100644 --- a/internal/types/entities/plugin_entities/model_configuration.go +++ b/internal/types/entities/plugin_entities/model_configuration.go @@ -1,8 +1,6 @@ package plugin_entities import ( - "encoding/json" - "github.com/go-playground/locales/en" ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" @@ -278,18 +276,3 @@ func init() { validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType) } - -func UnmarshalModelProviderConfiguration(data []byte) (*ModelProviderConfiguration, error) { - var modelProviderConfiguration ModelProviderConfiguration - err := json.Unmarshal(data, &modelProviderConfiguration) - if err != nil { - return nil, err - } - - err = validators.GlobalEntitiesValidator.Struct(modelProviderConfiguration) - if err != nil { - return nil, err - } - - return &modelProviderConfiguration, nil -} diff --git a/internal/types/entities/plugin_entities/model_configuration_test.go b/internal/types/entities/plugin_entities/model_configuration_test.go index 03ee786..9f609aa 100644 --- a/internal/types/entities/plugin_entities/model_configuration_test.go +++ b/internal/types/entities/plugin_entities/model_configuration_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "testing" + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" "gopkg.in/yaml.v3" ) @@ -156,7 +157,7 @@ func TestFullFunctionModelProvider_Validate(t *testing.T) { t.Error(err) } - _, err = UnmarshalModelProviderConfiguration(json_data) + _, err = parser.UnmarshalJsonBytes[ModelProviderConfiguration](json_data) if err != nil { t.Errorf("UnmarshalModelProviderConfiguration() error = %v", err) } diff --git a/internal/types/entities/plugin_entities/tool_configuration.go b/internal/types/entities/plugin_entities/tool_configuration.go index 4ae549b..b73bea5 100644 --- a/internal/types/entities/plugin_entities/tool_configuration.go +++ b/internal/types/entities/plugin_entities/tool_configuration.go @@ -1,7 +1,6 @@ package plugin_entities import ( - "encoding/json" "fmt" "github.com/go-playground/locales/en" @@ -253,24 +252,6 @@ func init() { validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType) } -func (t *ToolProviderConfiguration) UnmarshalJSON(data []byte) error { - type Alias ToolProviderConfiguration - aux := &struct { - *Alias - }{ - Alias: (*Alias)(t), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - if err := validators.GlobalEntitiesValidator.Struct(t); err != nil { - return err - } - - return nil -} - func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration, error) { obj, err := parser.UnmarshalJsonBytes[ToolProviderConfiguration](data) if err != nil { diff --git a/internal/types/entities/requests/model.go b/internal/types/entities/requests/model.go index 6643501..f5b2eee 100644 --- a/internal/types/entities/requests/model.go +++ b/internal/types/entities/requests/model.go @@ -1,10 +1,7 @@ package requests import ( - "encoding/json" - "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" - "github.com/langgenius/dify-plugin-daemon/internal/types/validators" ) type RequestInvokeLLM struct { @@ -18,21 +15,3 @@ type RequestInvokeLLM struct { Stream bool `json:"stream"` Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` } - -func (r *RequestInvokeLLM) UnmarshalJSON(data []byte) error { - type Alias RequestInvokeLLM - aux := &struct { - *Alias - }{ - Alias: (*Alias)(r), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - if err := validators.GlobalEntitiesValidator.Struct(r); err != nil { - return err - } - - return nil -} diff --git a/internal/types/entities/requests/tool.go b/internal/types/entities/requests/tool.go index 43d47d5..8b9b1d3 100644 --- a/internal/types/entities/requests/tool.go +++ b/internal/types/entities/requests/tool.go @@ -1,32 +1,8 @@ package requests -import ( - "encoding/json" - - "github.com/langgenius/dify-plugin-daemon/internal/types/validators" -) - type RequestInvokeTool struct { Provider string `json:"provider" validate:"required"` Tool string `json:"tool" validate:"required"` ToolParameters map[string]any `json:"tool_parameters" validate:"omitempty,dive,is_basic_type"` Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"` } - -func (r *RequestInvokeTool) UnmarshalJSON(data []byte) error { - type Alias RequestInvokeTool - aux := &struct { - *Alias - }{ - Alias: (*Alias)(r), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - if err := validators.GlobalEntitiesValidator.Struct(r); err != nil { - return err - } - - return nil -} diff --git a/internal/utils/parser/json.go b/internal/utils/parser/json.go index 2bd88c1..789ec13 100644 --- a/internal/utils/parser/json.go +++ b/internal/utils/parser/json.go @@ -1,6 +1,10 @@ package parser -import "encoding/json" +import ( + "encoding/json" + + "github.com/langgenius/dify-plugin-daemon/internal/types/validators" +) func UnmarshalJson[T any](text string) (T, error) { return UnmarshalJsonBytes[T]([]byte(text)) @@ -9,6 +13,14 @@ func UnmarshalJson[T any](text string) (T, error) { func UnmarshalJsonBytes[T any](data []byte) (T, error) { var result T err := json.Unmarshal(data, &result) + if err != nil { + return result, err + } + + if err := validators.GlobalEntitiesValidator.Struct(&result); err != nil { + return result, err + } + return result, err } diff --git a/internal/utils/parser/yaml.go b/internal/utils/parser/yaml.go index 75e6300..f041000 100644 --- a/internal/utils/parser/yaml.go +++ b/internal/utils/parser/yaml.go @@ -1,16 +1,26 @@ package parser import ( + "github.com/go-playground/validator/v10" "gopkg.in/yaml.v3" ) -func UnmarshalYaml[T any](text string) (T, error) { - return UnmarshalYamlBytes[T]([]byte(text)) +func UnmarshalYaml[T any](text string, validator ...validator.Validate) (T, error) { + return UnmarshalYamlBytes[T]([]byte(text), validator...) } -func UnmarshalYamlBytes[T any](data []byte) (T, error) { +func UnmarshalYamlBytes[T any](data []byte, validator ...validator.Validate) (T, error) { var result T err := yaml.Unmarshal(data, &result) + if err != nil { + return result, err + } + + if len(validator) > 0 { + if err := validator[0].Struct(result); err != nil { + return result, err + } + } return result, err }