diff --git a/internal/core/dify_invocation/types.go b/internal/core/dify_invocation/types.go index 454b8e5..f6a8cea 100644 --- a/internal/core/dify_invocation/types.go +++ b/internal/core/dify_invocation/types.go @@ -2,7 +2,6 @@ package dify_invocation import ( "encoding/json" - "fmt" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" ) @@ -13,23 +12,6 @@ type BaseInvokeDifyRequest struct { Type InvokeType `json:"type"` } -func (r *BaseInvokeDifyRequest) FromMap(data map[string]any) error { - var ok bool - if r.TenantId, ok = data["tenant_id"].(string); !ok { - return fmt.Errorf("tenant_id is not a string") - } - - if r.UserId, ok = data["user_id"].(string); !ok { - return fmt.Errorf("user_id is not a string") - } - - if r.Type, ok = data["type"].(InvokeType); !ok { - return fmt.Errorf("type is not a string") - } - - return nil -} - type InvokeType string const ( @@ -46,27 +28,6 @@ type InvokeModelRequest struct { Parameters map[string]any `json:"parameters"` } -func (r *InvokeModelRequest) FromMap(base map[string]any, data map[string]any) error { - var ok bool - if r.Provider, ok = data["provider"].(string); !ok { - return fmt.Errorf("provider is not a string") - } - - if r.Model, ok = data["model"].(string); !ok { - return fmt.Errorf("model is not a string") - } - - if r.ModelType, ok = data["model_type"].(model_entities.ModelType); !ok { - return fmt.Errorf("model_type is not a string") - } - - if r.Parameters, ok = data["parameters"].(map[string]any); !ok { - return fmt.Errorf("parameters is not a map") - } - - return nil -} - func (r InvokeModelRequest) MarshalJSON() ([]byte, error) { flattened := make(map[string]any) flattened["tenant_id"] = r.TenantId @@ -87,23 +48,6 @@ type InvokeToolRequest struct { Parameters map[string]any `json:"parameters"` } -func (r *InvokeToolRequest) FromMap(base map[string]any, data map[string]any) error { - var ok bool - if r.Provider, ok = data["provider"].(string); !ok { - return fmt.Errorf("provider is not a string") - } - - if r.Tool, ok = data["tool"].(string); !ok { - return fmt.Errorf("tool is not a string") - } - - if r.Parameters, ok = data["parameters"].(map[string]any); !ok { - return fmt.Errorf("parameters is not a map") - } - - return nil -} - func (r InvokeToolRequest) MarshalJSON() ([]byte, error) { flattened := make(map[string]any) flattened["tenant_id"] = r.TenantId @@ -123,19 +67,6 @@ type InvokeNodeRequest[T WorkflowNodeData] struct { NodeData T `json:"node_data"` } -func (r *InvokeNodeRequest[T]) FromMap(data map[string]any) error { - var ok bool - if r.NodeType, ok = data["node_type"].(NodeType); !ok { - return fmt.Errorf("node_type is not a string") - } - - if err := r.NodeData.FromMap(data["node_data"].(map[string]any)); err != nil { - return err - } - - return nil -} - func (r InvokeNodeRequest[T]) MarshalJSON() ([]byte, error) { flattened := make(map[string]any) flattened["tenant_id"] = r.TenantId diff --git a/internal/core/dify_invocation/workflow_node_data.go b/internal/core/dify_invocation/workflow_node_data.go index 95f514d..11f078b 100644 --- a/internal/core/dify_invocation/workflow_node_data.go +++ b/internal/core/dify_invocation/workflow_node_data.go @@ -1,9 +1,7 @@ package dify_invocation type WorkflowNodeData interface { - FromMap(map[string]any) error - - *KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData + KnowledgeRetrievalNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData } type NodeType string @@ -18,20 +16,8 @@ const ( type KnowledgeRetrievalNodeData struct { } -func (r *KnowledgeRetrievalNodeData) FromMap(data map[string]any) error { - return nil -} - type QuestionClassifierNodeData struct { } -func (r *QuestionClassifierNodeData) FromMap(data map[string]any) error { - return nil -} - type ParameterExtractorNodeData struct { } - -func (r *ParameterExtractorNodeData) FromMap(data map[string]any) error { - return nil -} diff --git a/internal/core/plugin_daemon/backwards_invocation/entities.go b/internal/core/plugin_daemon/backwards_invocation/entities.go new file mode 100644 index 0000000..4360743 --- /dev/null +++ b/internal/core/plugin_daemon/backwards_invocation/entities.go @@ -0,0 +1,43 @@ +package backwards_invocation + +type RequestEvent string + +const ( + REQUEST_EVENT_RESPONSE RequestEvent = "response" + REQUEST_EVENT_ERROR RequestEvent = "error" + REQUEST_EVENT_END RequestEvent = "end" +) + +type BaseRequestEvent struct { + BackwardsRequestId string `json:"backwards_request_id"` + Event RequestEvent `json:"event"` + Message string `json:"message"` + Data map[string]any `json:"data"` +} + +func NewResponseEvent(request_id string, message string, data map[string]any) *BaseRequestEvent { + return &BaseRequestEvent{ + BackwardsRequestId: request_id, + Event: REQUEST_EVENT_RESPONSE, + Message: message, + Data: data, + } +} + +func NewErrorEvent(request_id string, message string) *BaseRequestEvent { + return &BaseRequestEvent{ + BackwardsRequestId: request_id, + Event: REQUEST_EVENT_ERROR, + Message: message, + Data: nil, + } +} + +func NewEndEvent(request_id string) *BaseRequestEvent { + return &BaseRequestEvent{ + BackwardsRequestId: request_id, + Event: REQUEST_EVENT_END, + Message: "", + Data: nil, + } +} diff --git a/internal/core/plugin_daemon/backwards_invocation/request.go b/internal/core/plugin_daemon/backwards_invocation/request.go new file mode 100644 index 0000000..bb31e9c --- /dev/null +++ b/internal/core/plugin_daemon/backwards_invocation/request.go @@ -0,0 +1,52 @@ +package backwards_invocation + +import ( + "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation" + "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" +) + +type BackwardsInvocationType = dify_invocation.InvokeType + +type BackwardsInvocation struct { + typ BackwardsInvocationType + id string + detailed_request map[string]any + session *session_manager.Session +} + +func NewBackwardsInvocation( + typ BackwardsInvocationType, + id string, session *session_manager.Session, detailed_request map[string]any, +) *BackwardsInvocation { + return &BackwardsInvocation{ + typ: typ, + id: id, + detailed_request: detailed_request, + session: session, + } +} + +func (bi *BackwardsInvocation) GetID() string { + return bi.id +} + +func (bi *BackwardsInvocation) WriteError(err error) { + bi.session.Write(parser.MarshalJsonBytes(NewErrorEvent(bi.id, err.Error()))) +} + +func (bi *BackwardsInvocation) Write(message string, data map[string]any) { + bi.session.Write(parser.MarshalJsonBytes(NewResponseEvent(bi.id, message, data))) +} + +func (bi *BackwardsInvocation) End() { + bi.session.Write(parser.MarshalJsonBytes(NewEndEvent(bi.id))) +} + +func (bi *BackwardsInvocation) Type() BackwardsInvocationType { + return bi.typ +} + +func (bi *BackwardsInvocation) RequestData() map[string]any { + return bi.detailed_request +} diff --git a/internal/core/plugin_daemon/invoke_dify.go b/internal/core/plugin_daemon/invoke_dify.go index eb3e9f0..1a6cae6 100644 --- a/internal/core/plugin_daemon/invoke_dify.go +++ b/internal/core/plugin_daemon/invoke_dify.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation" + "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" "github.com/langgenius/dify-plugin-daemon/internal/types/entities" "github.com/langgenius/dify-plugin-daemon/internal/utils/log" @@ -13,6 +14,7 @@ import ( func invokeDify( runtime entities.PluginRuntimeInterface, + invoke_from PluginAccessType, session *session_manager.Session, data []byte, ) error { // unmarshal invoke data @@ -22,37 +24,67 @@ func invokeDify( return fmt.Errorf("unmarshal invoke request failed: %s", err.Error()) } + // prepare invocation arguments + request_handle, err := prepareDifyInvocationArguments(session, request) + if err != nil { + return err + } + defer request_handle.End() + + if invoke_from == PLUGIN_ACCESS_TYPE_MODEL { + request_handle.WriteError(fmt.Errorf("you can not invoke dify from %s", invoke_from)) + return nil + } + + // dispatch invocation task + dispatchDifyInvocationTask(request_handle) + + return nil +} + +func prepareDifyInvocationArguments(session *session_manager.Session, request map[string]any) (*backwards_invocation.BackwardsInvocation, error) { typ, ok := request["type"].(string) if !ok { - return fmt.Errorf("invoke request missing type: %s", data) + return nil, fmt.Errorf("invoke request missing type: %s", request) } // get request id - request_id, ok := request["request_id"].(string) + backwards_request_id, ok := request["backwards_request_id"].(string) if !ok { - return fmt.Errorf("invoke request missing request_id: %s", data) + return nil, fmt.Errorf("invoke request missing request_id: %s", request) } // get request detailed_request, ok := request["request"].(map[string]any) if !ok { - return fmt.Errorf("invoke request missing request: %s", data) + return nil, fmt.Errorf("invoke request missing request: %s", request) } - switch typ { - case "tool": - r := dify_invocation.InvokeToolRequest{} - if err := r.FromMap(request, detailed_request); err != nil { - return fmt.Errorf("unmarshal tool invoke request failed: %s", err.Error()) + return backwards_invocation.NewBackwardsInvocation( + backwards_invocation.BackwardsInvocationType(typ), + backwards_request_id, session, detailed_request, + ), nil +} + +func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) { + switch handle.Type() { + case dify_invocation.INVOKE_TYPE_TOOL: + r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData()) + if err != nil { + handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error())) + return } - submitToolTask(runtime, session, request_id, &r) - case "model": - r := dify_invocation.InvokeModelRequest{} - if err := r.FromMap(request, detailed_request); err != nil { - return fmt.Errorf("unmarshal model invoke request failed: %s", err.Error()) + + submitToolTask(runtime, session, backwards_request_id, &r) + case dify_invocation.INVOKE_TYPE_MODEL: + r, err := parser.MapToStruct[dify_invocation.InvokeModelRequest](handle.RequestData()) + if err != nil { + handle.WriteError(fmt.Errorf("unmarshal invoke model request failed: %s", err.Error())) + return } - submitModelTask(runtime, session, request_id, &r) - case "node": + + submitModelTask(runtime, session, backwards_request_id, &r) + case dify_invocation.INVOKE_TYPE_NODE: node_type, ok := detailed_request["node_type"].(dify_invocation.NodeType) if !ok { return fmt.Errorf("invoke request missing node_type: %s", data) @@ -63,40 +95,35 @@ func invokeDify( } switch node_type { case dify_invocation.QUESTION_CLASSIFIER: - d := dify_invocation.InvokeNodeRequest[*dify_invocation.QuestionClassifierNodeData]{ + d := dify_invocation.InvokeNodeRequest[dify_invocation.QuestionClassifierNodeData]{ NodeType: dify_invocation.QUESTION_CLASSIFIER, - NodeData: &dify_invocation.QuestionClassifierNodeData{}, } if err := d.FromMap(node_data); err != nil { return fmt.Errorf("unmarshal question classifier node data failed: %s", err.Error()) } - submitNodeInvocationRequestTask(runtime, session, request_id, &d) + submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d) case dify_invocation.KNOWLEDGE_RETRIEVAL: - d := dify_invocation.InvokeNodeRequest[*dify_invocation.KnowledgeRetrievalNodeData]{ + d := dify_invocation.InvokeNodeRequest[dify_invocation.KnowledgeRetrievalNodeData]{ NodeType: dify_invocation.KNOWLEDGE_RETRIEVAL, - NodeData: &dify_invocation.KnowledgeRetrievalNodeData{}, } if err := d.FromMap(node_data); err != nil { return fmt.Errorf("unmarshal knowledge retrieval node data failed: %s", err.Error()) } - submitNodeInvocationRequestTask(runtime, session, request_id, &d) + submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d) case dify_invocation.PARAMETER_EXTRACTOR: - d := dify_invocation.InvokeNodeRequest[*dify_invocation.ParameterExtractorNodeData]{ + d := dify_invocation.InvokeNodeRequest[dify_invocation.ParameterExtractorNodeData]{ NodeType: dify_invocation.PARAMETER_EXTRACTOR, - NodeData: &dify_invocation.ParameterExtractorNodeData{}, } if err := d.FromMap(node_data); err != nil { return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error()) } - submitNodeInvocationRequestTask(runtime, session, request_id, &d) + submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d) default: return fmt.Errorf("unknown node type: %s", node_type) } default: return fmt.Errorf("unknown invoke type: %s", typ) } - - return nil } func setTaskContext(session *session_manager.Session, r *dify_invocation.BaseInvokeDifyRequest) { diff --git a/internal/core/plugin_daemon/model_service.go b/internal/core/plugin_daemon/model_service.go index f7d7582..9cec55d 100644 --- a/internal/core/plugin_daemon/model_service.go +++ b/internal/core/plugin_daemon/model_service.go @@ -46,7 +46,7 @@ func genericInvokePlugin[Req any, Rsp any]( } response.Write(chunk) case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE: - invokeDify(runtime, session, chunk.Data) + invokeDify(runtime, typ, session, chunk.Data) case plugin_entities.SESSION_MESSAGE_TYPE_END: response.Close() case plugin_entities.SESSION_MESSAGE_TYPE_ERROR: diff --git a/internal/core/session_manager/session.go b/internal/core/session_manager/session.go index c178667..e438e91 100644 --- a/internal/core/session_manager/session.go +++ b/internal/core/session_manager/session.go @@ -1,9 +1,11 @@ package session_manager import ( + "errors" "sync" "github.com/google/uuid" + "github.com/langgenius/dify-plugin-daemon/internal/types/entities" ) var ( @@ -12,7 +14,8 @@ var ( ) type Session struct { - id string + id string + runtime entities.PluginRuntimeSessionIOInterface tenant_id string user_id string @@ -71,3 +74,15 @@ func (s *Session) UserID() string { func (s *Session) PluginIdentity() string { return s.plugin_identity } + +func (s *Session) BindRuntime(runtime entities.PluginRuntimeSessionIOInterface) { + s.runtime = runtime +} + +func (s *Session) Write(data []byte) error { + if s.runtime == nil { + return errors.New("runtime not bound") + } + s.runtime.Write(s.id, data) + return nil +} diff --git a/internal/service/invoke_model.go b/internal/service/invoke_model.go index c777ee1..99095a2 100644 --- a/internal/service/invoke_model.go +++ b/internal/service/invoke_model.go @@ -3,17 +3,15 @@ package service import ( "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon" - "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities" - "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) { @@ -23,7 +21,7 @@ func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeToo func ValidateToolCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ValidateCredentialsResult], error) { diff --git a/internal/service/invoke_tool.go b/internal/service/invoke_tool.go index cb7bbaf..ce74c55 100644 --- a/internal/service/invoke_tool.go +++ b/internal/service/invoke_tool.go @@ -3,6 +3,7 @@ package service import ( "github.com/gin-gonic/gin" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon" + "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" @@ -11,9 +12,16 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/utils/stream" ) +func createSession[T any](r *plugin_entities.InvokePluginRequest[T]) *session_manager.Session { + session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + runtime := plugin_manager.Get(session.PluginIdentity()) + session.BindRuntime(runtime) + return session +} + func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) { @@ -23,7 +31,7 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM] func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) { @@ -33,7 +41,7 @@ func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.Request func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.RerankResult], error) { @@ -43,7 +51,7 @@ func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeR func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.TTSResult], error) { @@ -53,7 +61,7 @@ func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS] func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) { @@ -63,7 +71,7 @@ func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestIn func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.ModerationResult], error) { @@ -73,7 +81,7 @@ func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInv func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) { @@ -83,7 +91,7 @@ func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests func ValidateModelCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials], ctx *gin.Context) { // create session - session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)) + session := createSession(r) defer session.Close() baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) { diff --git a/internal/service/session.go b/internal/service/session.go new file mode 100644 index 0000000..6d43c33 --- /dev/null +++ b/internal/service/session.go @@ -0,0 +1 @@ +package service diff --git a/internal/utils/parser/map2struct.go b/internal/utils/parser/map2struct.go new file mode 100644 index 0000000..34b8579 --- /dev/null +++ b/internal/utils/parser/map2struct.go @@ -0,0 +1,29 @@ +package parser + +import ( + "fmt" + + "github.com/mitchellh/mapstructure" +) + +func MapToStruct[T any](m map[string]any) (*T, error) { + var s T + decoder := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: &s, + TagName: "json", + Squash: true, + } + + d, err := mapstructure.NewDecoder(decoder) + if err != nil { + return nil, fmt.Errorf("error creating decoder: %s", err) + } + + err = d.Decode(m) + if err != nil { + return nil, fmt.Errorf("error decoding map: %s", err) + } + + return &s, nil +} diff --git a/internal/utils/parser/map2struct_test.go b/internal/utils/parser/map2struct_test.go new file mode 100644 index 0000000..f180925 --- /dev/null +++ b/internal/utils/parser/map2struct_test.go @@ -0,0 +1,48 @@ +package parser + +import "testing" + +func TestMapToStruct(t *testing.T) { + m := map[string]any{ + "result": "result", + "inherit": map[string]any{ + "inherit_result": "result", + }, + "object": map[string]any{ + "a": 1, + }, + } + + type p struct { + Inherit struct { + InheritResult string `json:"inherit_result"` + } + } + + type s struct { + p + + Result string `json:"result"` + Object struct { + A int `json:"a"` + } `json:"object"` + } + + result, err := MapToStruct[s](m) + if err != nil { + t.Error(err) + } + + if result.Result != "result" { + t.Error("result should be result") + } + + if result.Inherit.InheritResult != "result" { + t.Error("inherit_result should be result") + } + + if result.Object.A != 1 { + t.Error("a should be 1") + } + +} diff --git a/internal/utils/parser/struct2map.go b/internal/utils/parser/struct2map.go index 7405c90..bbb2610 100644 --- a/internal/utils/parser/struct2map.go +++ b/internal/utils/parser/struct2map.go @@ -1,48 +1,28 @@ package parser import ( - "reflect" - "unicode" + "github.com/mitchellh/mapstructure" ) func StructToMap(data interface{}) map[string]interface{} { result := make(map[string]interface{}) - val := reflect.ValueOf(data) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - typeField := val.Type().Field(i) - fieldName := toSnakeCase(typeField.Name) - if typeField.Anonymous { - embeddedFields := StructToMap(field.Interface()) - for k, v := range embeddedFields { - result[k] = v - } - } else { - result[fieldName] = field.Interface() - } + decoder := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: &result, + TagName: "json", + Squash: true, } - return result -} -func toSnakeCase(str string) string { - runes := []rune(str) - length := len(runes) - var out []rune + d, err := mapstructure.NewDecoder(decoder) + if err != nil { + return nil + } - for i := 0; i < length; i++ { - if unicode.IsUpper(runes[i]) { - if i > 0 { - out = append(out, '_') - } - out = append(out, unicode.ToLower(runes[i])) - } else { - out = append(out, runes[i]) - } + err = d.Decode(data) + if err != nil { + return nil } - return string(out) + return result } diff --git a/internal/utils/parser/struct2map_test.go b/internal/utils/parser/struct2map_test.go new file mode 100644 index 0000000..8166437 --- /dev/null +++ b/internal/utils/parser/struct2map_test.go @@ -0,0 +1,32 @@ +package parser + +import "testing" + +func TestStruct2Map(t *testing.T) { + type Base struct { + A int `json:"a"` + } + + type p struct { + Base + + B int `json:"b"` + } + + d := p{ + Base: Base{ + A: 1, + }, + B: 2, + } + + result := StructToMap(d) + + if result["a"] != 1 { + t.Error("a should be 1") + } + + if result["b"] != 2 { + t.Error("b should be 2") + } +}