diff --git a/cmd/commandline/init/category.go b/cmd/commandline/init/category.go index 58cd570..dcc4b47 100644 --- a/cmd/commandline/init/category.go +++ b/cmd/commandline/init/category.go @@ -12,7 +12,12 @@ type category struct { var categories = []string{ "tool", - "model", + "llm", + "text-embedding", + "rerank", + "tts", + "speech2text", + "moderation", "extension", } @@ -56,11 +61,6 @@ func (c category) Update(msg tea.Msg) (subMenu, subMenuEvent, tea.Cmd) { c.cursor = 0 } case "enter": - if c.cursor != 0 { - c.cursor = 0 - return c, SUB_MENU_EVENT_NONE, nil - } - return c, SUB_MENU_EVENT_NEXT, nil } } diff --git a/cmd/commandline/init/init.go b/cmd/commandline/init/init.go index 5f6c372..cf33c11 100644 --- a/cmd/commandline/init/init.go +++ b/cmd/commandline/init/init.go @@ -141,6 +141,19 @@ func (m model) createPlugin() { manifest.Plugins.Tools = []string{fmt.Sprintf("provider/%s.yaml", manifest.Name)} } + if category_string == "llm" || + category_string == "text-embedding" || + category_string == "speech2text" || + category_string == "moderation" || + category_string == "rerank" || + category_string == "tts" { + manifest.Plugins.Models = []string{fmt.Sprintf("provider/%s.yaml", manifest.Name)} + } + + if category_string == "extension" { + manifest.Plugins.Endpoints = []string{fmt.Sprintf("group/%s.yaml", manifest.Name)} + } + manifest.Meta = plugin_entities.PluginMeta{ Version: "0.0.1", Arch: []constants.Arch{ diff --git a/cmd/commandline/init/python.go b/cmd/commandline/init/python.go index 3c00473..19503e1 100644 --- a/cmd/commandline/init/python.go +++ b/cmd/commandline/init/python.go @@ -1,13 +1,16 @@ package init import ( + "bytes" _ "embed" "fmt" + "html/template" "os" "path/filepath" - "strings" "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + "github.com/langgenius/dify-plugin-daemon/internal/utils/parser" + "github.com/langgenius/dify-plugin-daemon/internal/utils/strings" ) //go:embed templates/python/main.py @@ -28,6 +31,87 @@ var PYTHON_TOOL_PY_TEMPLATE []byte //go:embed templates/python/tool_provider.py var PYTHON_TOOL_PROVIDER_PY_TEMPLATE []byte +//go:embed templates/python/model_provider.py +var PYTHON_MODEL_PROVIDER_PY_TEMPLATE []byte + +//go:embed templates/python/model_provider.yaml +var PYTHON_MODEL_PROVIDER_TEMPLATE []byte + +//go:embed templates/python/llm.py +var PYTHON_LLM_TEMPLATE []byte + +//go:embed templates/python/llm.yaml +var PYTHON_LLM_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/text-embedding.py +var PYTHON_TEXT_EMBEDDING_TEMPLATE []byte + +//go:embed templates/python/text-embedding.yaml +var PYTHON_TEXT_EMBEDDING_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/rerank.py +var PYTHON_RERANK_TEMPLATE []byte + +//go:embed templates/python/rerank.yaml +var PYTHON_RERANK_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/tts.py +var PYTHON_TTS_TEMPLATE []byte + +//go:embed templates/python/tts.yaml +var PYTHON_TTS_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/speech2text.py +var PYTHON_SPEECH2TEXT_TEMPLATE []byte + +//go:embed templates/python/speech2text.yaml +var PYTHON_SPEECH2TEXT_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/moderation.py +var PYTHON_MODERATION_TEMPLATE []byte + +//go:embed templates/python/moderation.yaml +var PYTHON_MODERATION_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/endpoint_group.yaml +var PYTHON_ENDPOINT_GROUP_MANIFEST_TEMPLATE []byte + +//go:embed templates/python/endpoint.py +var PYTHON_ENDPOINT_TEMPLATE []byte + +//go:embed templates/python/endpoint.yaml +var PYTHON_ENDPOINT_MANIFEST_TEMPLATE []byte + +func renderTemplate( + original_template []byte, manifest *plugin_entities.PluginDeclaration, supported_model_types []string, +) (string, error) { + tmpl := template.Must(template.New("").Funcs(template.FuncMap{ + "SnakeToCamel": parser.SnakeToCamel, + "HasSubstring": func(substring string, haystack []string) bool { + return strings.Find(haystack, substring) + }, + }).Parse(string(original_template))) + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, map[string]interface{}{ + "PluginName": manifest.Name, + "Author": manifest.Author, + "PluginDescription": manifest.Description.EnUS, + "SupportedModelTypes": supported_model_types, + }); err != nil { + return "", err + } + return buf.String(), nil +} + +func writeFile(path string, content string) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + return os.WriteFile(path, []byte(content), 0o644) +} + func createPythonEnvironment( root string, entrypoint string, manifest *plugin_entities.PluginDeclaration, category string, ) error { @@ -54,88 +138,56 @@ func createPythonEnvironment( } } - return nil -} + if category == "extension" { + if err := createPythonEndpointGroup(root, manifest); err != nil { + return err + } -func createPythonTool(root string, manifest *plugin_entities.PluginDeclaration) error { - // create the tool - tool_dir := filepath.Join(root, "tools") - if err := os.MkdirAll(tool_dir, 0o755); err != nil { - return err - } - // replace the plugin name/author/description in the template - tool_file_content := strings.ReplaceAll( - string(PYTHON_TOOL_PY_TEMPLATE), "{{plugin_name}}", manifest.Name, - ) - tool_file_content = strings.ReplaceAll( - tool_file_content, "{{author}}", manifest.Author, - ) - tool_file_content = strings.ReplaceAll( - tool_file_content, "{{plugin_description}}", manifest.Description.EnUS, - ) - tool_file_path := filepath.Join(tool_dir, fmt.Sprintf("%s.py", manifest.Name)) - if err := os.WriteFile(tool_file_path, []byte(tool_file_content), 0o644); err != nil { - return err + if err := createPythonEndpoint(root, manifest); err != nil { + return err + } } - // create the tool manifest - tool_manifest_file_path := filepath.Join(tool_dir, fmt.Sprintf("%s.yaml", manifest.Name)) - if err := os.WriteFile(tool_manifest_file_path, PYTHON_TOOL_TEMPLATE, 0o644); err != nil { - return err + if category == "llm" || category == "text-embedding" || category == "speech2text" || category == "moderation" || category == "rerank" || category == "tts" { + if err := createPythonModelProvider(root, manifest, []string{category}); err != nil { + return err + } } - tool_manifest_file_content := strings.ReplaceAll( - string(PYTHON_TOOL_TEMPLATE), "{{plugin_name}}", manifest.Name, - ) - tool_manifest_file_content = strings.ReplaceAll( - tool_manifest_file_content, "{{author}}", manifest.Author, - ) - tool_manifest_file_content = strings.ReplaceAll( - tool_manifest_file_content, "{{plugin_description}}", manifest.Description.EnUS, - ) - if err := os.WriteFile(tool_manifest_file_path, []byte(tool_manifest_file_content), 0o644); err != nil { - return err + + if category == "llm" { + if err := createPythonLLM(root, manifest); err != nil { + return err + } } - return nil -} + if category == "text-embedding" { + if err := createPythonTextEmbedding(root, manifest); err != nil { + return err + } + } -func createPythonToolProvider(root string, manifest *plugin_entities.PluginDeclaration) error { - // create the tool provider - tool_provider_dir := filepath.Join(root, "provider") - if err := os.MkdirAll(tool_provider_dir, 0o755); err != nil { - return err + if category == "speech2text" { + if err := createPythonSpeech2Text(root, manifest); err != nil { + return err + } } - // replace the plugin name/author/description in the template - tool_provider_file_content := strings.ReplaceAll( - string(PYTHON_TOOL_PROVIDER_PY_TEMPLATE), "{{plugin_name}}", manifest.Name, - ) - tool_provider_file_content = strings.ReplaceAll( - tool_provider_file_content, "{{author}}", manifest.Author, - ) - tool_provider_file_content = strings.ReplaceAll( - tool_provider_file_content, "{{plugin_description}}", manifest.Description.EnUS, - ) - tool_provider_file_path := filepath.Join(tool_provider_dir, fmt.Sprintf("%s.py", manifest.Name)) - if err := os.WriteFile(tool_provider_file_path, []byte(tool_provider_file_content), 0o644); err != nil { - return err + + if category == "moderation" { + if err := createPythonModeration(root, manifest); err != nil { + return err + } } - // create the tool provider manifest - tool_provider_manifest_file_path := filepath.Join(tool_provider_dir, fmt.Sprintf("%s.yaml", manifest.Name)) - if err := os.WriteFile(tool_provider_manifest_file_path, PYTHON_TOOL_PROVIDER_TEMPLATE, 0o644); err != nil { - return err + if category == "rerank" { + if err := createPythonRerank(root, manifest); err != nil { + return err + } } - tool_provider_manifest_file_content := strings.ReplaceAll( - string(PYTHON_TOOL_PROVIDER_TEMPLATE), "{{plugin_name}}", manifest.Name, - ) - tool_provider_manifest_file_content = strings.ReplaceAll( - tool_provider_manifest_file_content, "{{author}}", manifest.Author, - ) - tool_provider_manifest_file_content = strings.ReplaceAll( - tool_provider_manifest_file_content, "{{plugin_description}}", manifest.Description.EnUS, - ) - if err := os.WriteFile(tool_provider_manifest_file_path, []byte(tool_provider_manifest_file_content), 0o644); err != nil { - return err + + if category == "tts" { + if err := createPythonTTS(root, manifest); err != nil { + return err + } } return nil diff --git a/cmd/commandline/init/python_categories.go b/cmd/commandline/init/python_categories.go new file mode 100644 index 0000000..5d09ca3 --- /dev/null +++ b/cmd/commandline/init/python_categories.go @@ -0,0 +1,241 @@ +package init + +import ( + "fmt" + "path/filepath" + + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" +) + +func createPythonTool(root string, manifest *plugin_entities.PluginDeclaration) error { + tool_file_content, err := renderTemplate(PYTHON_TOOL_PY_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + tool_file_path := filepath.Join(root, "tools", fmt.Sprintf("%s.py", manifest.Name)) + if err := writeFile(tool_file_path, tool_file_content); err != nil { + return err + } + + tool_manifest_file_path := filepath.Join(root, "tools", fmt.Sprintf("%s.yaml", manifest.Name)) + tool_manifest_file_content, err := renderTemplate(PYTHON_TOOL_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + if err := writeFile(tool_manifest_file_path, tool_manifest_file_content); err != nil { + return err + } + + return nil +} + +func createPythonToolProvider(root string, manifest *plugin_entities.PluginDeclaration) error { + tool_provider_file_content, err := renderTemplate(PYTHON_TOOL_PROVIDER_PY_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + tool_provider_file_path := filepath.Join(root, "provider", fmt.Sprintf("%s.py", manifest.Name)) + if err := writeFile(tool_provider_file_path, tool_provider_file_content); err != nil { + return err + } + + tool_provider_manifest_file_content, err := renderTemplate(PYTHON_TOOL_PROVIDER_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + tool_provider_manifest_file_path := filepath.Join(root, "provider", fmt.Sprintf("%s.yaml", manifest.Name)) + if err := writeFile(tool_provider_manifest_file_path, tool_provider_manifest_file_content); err != nil { + return err + } + + return nil +} + +func createPythonEndpointGroup(root string, manifest *plugin_entities.PluginDeclaration) error { + endpoint_group_file_content, err := renderTemplate(PYTHON_ENDPOINT_GROUP_MANIFEST_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + endpoint_group_file_path := filepath.Join(root, "group", fmt.Sprintf("%s.yaml", manifest.Name)) + if err := writeFile(endpoint_group_file_path, endpoint_group_file_content); err != nil { + return err + } + + return nil +} + +func createPythonEndpoint(root string, manifest *plugin_entities.PluginDeclaration) error { + endpoint_file_content, err := renderTemplate(PYTHON_ENDPOINT_MANIFEST_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + endpoint_file_path := filepath.Join(root, "endpoints", fmt.Sprintf("%s.yaml", manifest.Name)) + if err := writeFile(endpoint_file_path, endpoint_file_content); err != nil { + return err + } + + endpoint_py_file_content, err := renderTemplate(PYTHON_ENDPOINT_TEMPLATE, manifest, []string{""}) + if err != nil { + return err + } + endpoint_py_file_path := filepath.Join(root, "endpoints", fmt.Sprintf("%s.py", manifest.Name)) + if err := writeFile(endpoint_py_file_path, endpoint_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonLLM(root string, manifest *plugin_entities.PluginDeclaration) error { + llm_file_content, err := renderTemplate(PYTHON_LLM_MANIFEST_TEMPLATE, manifest, []string{"llm"}) + if err != nil { + return err + } + llm_file_path := filepath.Join(root, "models", "llm", "llm.yaml") + if err := writeFile(llm_file_path, llm_file_content); err != nil { + return err + } + + llm_py_file_content, err := renderTemplate(PYTHON_LLM_TEMPLATE, manifest, []string{"llm"}) + if err != nil { + return err + } + llm_py_file_path := filepath.Join(root, "models", "llm", "llm.py") + if err := writeFile(llm_py_file_path, llm_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonTextEmbedding(root string, manifest *plugin_entities.PluginDeclaration) error { + text_embedding_file_content, err := renderTemplate(PYTHON_TEXT_EMBEDDING_MANIFEST_TEMPLATE, manifest, []string{"text_embedding"}) + if err != nil { + return err + } + text_embedding_file_path := filepath.Join(root, "models", "text_embedding", "text_embedding.yaml") + if err := writeFile(text_embedding_file_path, text_embedding_file_content); err != nil { + return err + } + + text_embedding_py_file_content, err := renderTemplate(PYTHON_TEXT_EMBEDDING_TEMPLATE, manifest, []string{"text_embedding"}) + if err != nil { + return err + } + text_embedding_py_file_path := filepath.Join(root, "models", "text_embedding", "text_embedding.py") + if err := writeFile(text_embedding_py_file_path, text_embedding_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonRerank(root string, manifest *plugin_entities.PluginDeclaration) error { + rerank_file_content, err := renderTemplate(PYTHON_RERANK_MANIFEST_TEMPLATE, manifest, []string{"rerank"}) + if err != nil { + return err + } + rerank_file_path := filepath.Join(root, "models", "rerank", "rerank.yaml") + if err := writeFile(rerank_file_path, rerank_file_content); err != nil { + return err + } + + rerank_py_file_content, err := renderTemplate(PYTHON_RERANK_TEMPLATE, manifest, []string{"rerank"}) + if err != nil { + return err + } + rerank_py_file_path := filepath.Join(root, "models", "rerank", "rerank.py") + if err := writeFile(rerank_py_file_path, rerank_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonTTS(root string, manifest *plugin_entities.PluginDeclaration) error { + tts_file_content, err := renderTemplate(PYTHON_TTS_MANIFEST_TEMPLATE, manifest, []string{"tts"}) + if err != nil { + return err + } + tts_file_path := filepath.Join(root, "models", "tts", "tts.yaml") + if err := writeFile(tts_file_path, tts_file_content); err != nil { + return err + } + + tts_py_file_content, err := renderTemplate(PYTHON_TTS_TEMPLATE, manifest, []string{"tts"}) + if err != nil { + return err + } + tts_py_file_path := filepath.Join(root, "models", "tts", "tts.py") + if err := writeFile(tts_py_file_path, tts_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonSpeech2Text(root string, manifest *plugin_entities.PluginDeclaration) error { + speech2text_file_content, err := renderTemplate(PYTHON_SPEECH2TEXT_MANIFEST_TEMPLATE, manifest, []string{"speech2text"}) + if err != nil { + return err + } + speech2text_file_path := filepath.Join(root, "models", "speech2text", "speech2text.yaml") + if err := writeFile(speech2text_file_path, speech2text_file_content); err != nil { + return err + } + + speech2text_py_file_content, err := renderTemplate(PYTHON_SPEECH2TEXT_TEMPLATE, manifest, []string{"speech2text"}) + if err != nil { + return err + } + speech2text_py_file_path := filepath.Join(root, "models", "speech2text", "speech2text.py") + if err := writeFile(speech2text_py_file_path, speech2text_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonModeration(root string, manifest *plugin_entities.PluginDeclaration) error { + moderation_file_content, err := renderTemplate(PYTHON_MODERATION_MANIFEST_TEMPLATE, manifest, []string{"moderation"}) + if err != nil { + return err + } + moderation_file_path := filepath.Join(root, "models", "moderation", "moderation.yaml") + if err := writeFile(moderation_file_path, moderation_file_content); err != nil { + return err + } + + moderation_py_file_content, err := renderTemplate(PYTHON_MODERATION_TEMPLATE, manifest, []string{"moderation"}) + if err != nil { + return err + } + moderation_py_file_path := filepath.Join(root, "models", "moderation", "moderation.py") + if err := writeFile(moderation_py_file_path, moderation_py_file_content); err != nil { + return err + } + + return nil +} + +func createPythonModelProvider(root string, manifest *plugin_entities.PluginDeclaration, supported_model_types []string) error { + provider_file_content, err := renderTemplate(PYTHON_MODEL_PROVIDER_PY_TEMPLATE, manifest, supported_model_types) + if err != nil { + return err + } + provider_file_path := filepath.Join(root, "provider", fmt.Sprintf("%s.py", manifest.Name)) + if err := writeFile(provider_file_path, provider_file_content); err != nil { + return err + } + + provider_manifest_file_content, err := renderTemplate(PYTHON_MODEL_PROVIDER_TEMPLATE, manifest, supported_model_types) + if err != nil { + return err + } + provider_manifest_file_path := filepath.Join(root, "provider", fmt.Sprintf("%s.yaml", manifest.Name)) + if err := writeFile(provider_manifest_file_path, provider_manifest_file_content); err != nil { + return err + } + + return nil +} diff --git a/cmd/commandline/init/render_template_test.go b/cmd/commandline/init/render_template_test.go new file mode 100644 index 0000000..a92c28d --- /dev/null +++ b/cmd/commandline/init/render_template_test.go @@ -0,0 +1,38 @@ +package init + +import ( + "strings" + "testing" + + "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" +) + +func TestRenderPythonToolTemplate(t *testing.T) { + manifest := &plugin_entities.PluginDeclaration{ + PluginDeclarationWithoutAdvancedFields: plugin_entities.PluginDeclarationWithoutAdvancedFields{ + Name: "test", + Author: "test", + Description: plugin_entities.I18nObject{ + EnUS: "test", + }, + }, + } + + content, err := renderTemplate(PYTHON_TOOL_PY_TEMPLATE, manifest, []string{""}) + if err != nil { + t.Errorf("failed to render template: %v", err) + } + + if !strings.Contains(content, "TestTool") { + t.Errorf("template content does not contain TestTool, snakeToCamel failed") + } + + content, err = renderTemplate(PYTHON_TOOL_PROVIDER_TEMPLATE, manifest, []string{""}) + if err != nil { + t.Errorf("failed to render template: %v", err) + } + + if !strings.Contains(content, "test") { + t.Errorf("template content does not contain TestTool, snakeToCamel failed") + } +} diff --git a/cmd/commandline/init/templates/python/endpoint.py b/cmd/commandline/init/templates/python/endpoint.py new file mode 100644 index 0000000..1ffb80d --- /dev/null +++ b/cmd/commandline/init/templates/python/endpoint.py @@ -0,0 +1,17 @@ +import time +from typing import Mapping +from werkzeug import Request, Response +from dify_plugin import Endpoint + + +class {{ .PluginName | SnakeToCamel }}Endpoint(Endpoint): + def _invoke(self, r: Request, values: Mapping, settings: Mapping) -> Response: + """ + Invokes the endpoint with the given request. + """ + def generator(): + for i in range(10): + time.sleep(1) + yield f"{i}
" + + return Response(generator(), status=200, content_type="text/html") diff --git a/cmd/commandline/init/templates/python/endpoint.yaml b/cmd/commandline/init/templates/python/endpoint.yaml new file mode 100644 index 0000000..7cb0ccb --- /dev/null +++ b/cmd/commandline/init/templates/python/endpoint.yaml @@ -0,0 +1,5 @@ +path: "/uwu/{{ .PluginName }}/uwu" +method: "GET" +extra: + python: + source: "endpoints/{{ .PluginName }}.py" diff --git a/cmd/commandline/init/templates/python/endpoint_group.yaml b/cmd/commandline/init/templates/python/endpoint_group.yaml new file mode 100644 index 0000000..f12e864 --- /dev/null +++ b/cmd/commandline/init/templates/python/endpoint_group.yaml @@ -0,0 +1,14 @@ +settings: + - name: api_key + type: secret-input + required: true + label: + en_US: API key + zh_Hans: API key + pt_BR: API key + placeholder: + en_US: Please input your API key + zh_Hans: 请输入你的 API key + pt_BR: Please input your API key +endpoints: + - endpoints/{{ .PluginName }}.yaml diff --git a/cmd/commandline/init/templates/python/llm.py b/cmd/commandline/init/templates/python/llm.py new file mode 100644 index 0000000..7b47368 --- /dev/null +++ b/cmd/commandline/init/templates/python/llm.py @@ -0,0 +1,109 @@ +import logging +from collections.abc import Generator +from typing import Optional, Union + +from dify_plugin.entities import I18nObject +from dify_plugin.errors.model import ( + CredentialsValidateFailedError, +) +from dify_plugin.entities.model import ( + AIModelEntity, + FetchFrom, + ModelType, +) +from dify_plugin.entities.model.llm import ( + LLMResult, +) +from dify_plugin.entities.model.message import ( + PromptMessage, + PromptMessageTool, +) + +logger = logging.getLogger(__name__) + + +class {{ .PluginName | SnakeToCamel }}LargeLanguageModel(LargeLanguageModel): + """ + Model class for {{ .PluginName }} large language model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + pass + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + pass + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: + """ + If your model supports fine-tuning, this method returns the schema of the base model + but renamed to the fine-tuned model name. + + :param model: model name + :param credentials: credentials + + :return: model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(zh_Hans=model, en_US=model), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/cmd/commandline/init/templates/python/llm.yaml b/cmd/commandline/init/templates/python/llm.yaml new file mode 100644 index 0000000..4a0e2ef --- /dev/null +++ b/cmd/commandline/init/templates/python/llm.yaml @@ -0,0 +1,33 @@ +model: gpt-3.5-turbo-16k-0613 +label: + zh_Hans: gpt-3.5-turbo-16k-0613 + en_US: gpt-3.5-turbo-16k-0613 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 16385 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16385 + - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.004' + unit: '0.001' + currency: USD diff --git a/cmd/commandline/init/templates/python/model_provider.py b/cmd/commandline/init/templates/python/model_provider.py new file mode 100644 index 0000000..68d4a56 --- /dev/null +++ b/cmd/commandline/init/templates/python/model_provider.py @@ -0,0 +1,27 @@ +import logging +from collections.abc import Mapping + +from dify_plugin import ModelProvider +from dify_plugin.entities.model import ModelType +from dify_plugin.errors.model import CredentialsValidateFailedError + +logger = logging.getLogger(__name__) + + +class {{ .PluginName | SnakeToCamel }}ModelProvider(ModelProvider): + def validate_provider_credentials(self, credentials: Mapping) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + pass + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception( + f"{self.get_provider_schema().provider} credentials validate failed" + ) + raise ex diff --git a/cmd/commandline/init/templates/python/model_provider.yaml b/cmd/commandline/init/templates/python/model_provider.yaml new file mode 100644 index 0000000..97c826d --- /dev/null +++ b/cmd/commandline/init/templates/python/model_provider.yaml @@ -0,0 +1,96 @@ +provider: {{ .PluginName }} +label: + en_US: {{ .PluginName | SnakeToCamel }} +description: + en_US: Models provided by {{ .PluginName }}. + zh_Hans: {{ .PluginName | SnakeToCamel }} 提供的模型。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#E5E7EB" +help: + title: + en_US: Get your API Key from {{ .PluginName }} + zh_Hans: 从 {{ .PluginName | SnakeToCamel }} 获取 API Key + url: + en_US: https://__put_your_url_here__/account/api-keys +supported_model_types: +{{- range .SupportedModelTypes }} + - {{ . }} +{{- end }} +configurate_methods: + - predefined-model + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: openai_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key +provider_credential_schema: + credential_form_schemas: + - variable: openai_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key +models: +{{- if HasSubstring "llm" .SupportedModelTypes }} + llm: + predefined: + - "models/llm/*.yaml" +{{- end }} +{{- if HasSubstring "text_embedding" .SupportedModelTypes }} + text_embedding: + predefined: + - "models/text_embedding/*.yaml" +{{- end }} +{{- if HasSubstring "tts" .SupportedModelTypes }} + tts: + predefined: + - "models/tts/*.yaml" +{{- end }} +{{- if HasSubstring "speech2text" .SupportedModelTypes }} + speech2text: + predefined: + - "models/speech2text/*.yaml" +{{- end }} +{{- if HasSubstring "moderation" .SupportedModelTypes }} + moderation: + predefined: + - "models/moderation/*.yaml" +{{- end }} +extra: + python: + provider_source: provider/openai.py + model_sources: +{{- if HasSubstring "llm" .SupportedModelTypes }} + - "models/llm/llm.py" +{{- end }} +{{- if HasSubstring "text-embedding" .SupportedModelTypes }} + - "models/text_embedding/text_embedding.py" +{{- end }} +{{- if HasSubstring "speech2text" .SupportedModelTypes }} + - "models/speech2text/speech2text.py" +{{- end }} +{{- if HasSubstring "moderation" .SupportedModelTypes }} + - "models/moderation/moderation.py" +{{- end }} +{{- if HasSubstring "tts" .SupportedModelTypes }} + - "models/tts/tts.py" +{{- end }} diff --git a/cmd/commandline/init/templates/python/moderation.py b/cmd/commandline/init/templates/python/moderation.py new file mode 100644 index 0000000..756e3c2 --- /dev/null +++ b/cmd/commandline/init/templates/python/moderation.py @@ -0,0 +1,37 @@ +from typing import Optional + +from dify_plugin.errors.model import CredentialsValidateFailedError +from dify_plugin import ModerationModel + +class {{ .PluginName | SnakeToCamel }}ModerationModel(ModerationModel): + """ + Model class for {{ .PluginName | CamelToTitle }} text moderation model. + """ + + def _invoke(self, model: str, credentials: dict, + text: str, user: Optional[str] = None) \ + -> bool: + """ + Invoke moderation model + + :param model: model name + :param credentials: model credentials + :param text: text to moderate + :param user: unique user id + :return: false if text is safe, true otherwise + """ + # transform credentials to kwargs for model instance + return True + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + pass + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/cmd/commandline/init/templates/python/moderation.yaml b/cmd/commandline/init/templates/python/moderation.yaml new file mode 100644 index 0000000..5ca1809 --- /dev/null +++ b/cmd/commandline/init/templates/python/moderation.yaml @@ -0,0 +1,5 @@ +model: text-moderation-stable +model_type: moderation +model_properties: + max_chunks: 32 + max_characters_per_chunk: 2000 diff --git a/cmd/commandline/init/templates/python/rerank.py b/cmd/commandline/init/templates/python/rerank.py new file mode 100644 index 0000000..c38b235 --- /dev/null +++ b/cmd/commandline/init/templates/python/rerank.py @@ -0,0 +1,104 @@ +from typing import Optional + +import httpx + +from dify_plugin import RerankModel +from dify_plugin.entities import I18nObject +from dify_plugin.entities.model import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from dify_plugin.errors.model import ( + CredentialsValidateFailedError, + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_plugin.entities.model.rerank import ( + RerankDocument, + RerankResult, +) + +class {{ .PluginName | SnakeToCamel }}RerankModel(RerankModel): + """ + Model class for {{ .PluginName | SnakeToCamel }} rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + pass + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema( + self, model: str, credentials: dict + ) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size") or 0) + }, + ) + + return entity diff --git a/cmd/commandline/init/templates/python/rerank.yaml b/cmd/commandline/init/templates/python/rerank.yaml new file mode 100644 index 0000000..acf5767 --- /dev/null +++ b/cmd/commandline/init/templates/python/rerank.yaml @@ -0,0 +1,4 @@ +model: jina-reranker-v2-base-multilingual +model_type: rerank +model_properties: + context_size: 8192 diff --git a/cmd/commandline/init/templates/python/speech2text.py b/cmd/commandline/init/templates/python/speech2text.py new file mode 100644 index 0000000..81f375e --- /dev/null +++ b/cmd/commandline/init/templates/python/speech2text.py @@ -0,0 +1,35 @@ +from typing import IO, Optional + +from dify_plugin.errors.model import CredentialsValidateFailedError + +class {{ .PluginName | SnakeToCamel }}Speech2TextModel(Speech2TextModel): + """ + Model class for OpenAI Speech to text model. + """ + + def _invoke(self, model: str, credentials: dict, + file: IO[bytes], user: Optional[str] = None) \ + -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + pass + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + pass + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/cmd/commandline/init/templates/python/speech2text.yaml b/cmd/commandline/init/templates/python/speech2text.yaml new file mode 100644 index 0000000..6c14c76 --- /dev/null +++ b/cmd/commandline/init/templates/python/speech2text.yaml @@ -0,0 +1,5 @@ +model: whisper-1 +model_type: speech2text +model_properties: + file_upload_limit: 25 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/cmd/commandline/init/templates/python/text-embedding.py b/cmd/commandline/init/templates/python/text-embedding.py new file mode 100644 index 0000000..d037461 --- /dev/null +++ b/cmd/commandline/init/templates/python/text-embedding.py @@ -0,0 +1,57 @@ +from typing import Optional + +import numpy as np + +from dify_plugin.entities.model import EmbeddingInputType +from dify_plugin.errors.model import CredentialsValidateFailedError +from dify_plugin.entities.model.text_embedding import ( + TextEmbeddingResult, +) + +class {{ .PluginName | SnakeToCamel }}TextEmbeddingModel(TextEmbeddingModel): + """ + Model class for {{ .PluginName }} text embedding model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + pass + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + pass + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/cmd/commandline/init/templates/python/text-embedding.yaml b/cmd/commandline/init/templates/python/text-embedding.yaml new file mode 100644 index 0000000..ef1c49b --- /dev/null +++ b/cmd/commandline/init/templates/python/text-embedding.yaml @@ -0,0 +1,9 @@ +model: text-embedding-ada-002 +model_type: text-embedding +model_properties: + context_size: 8097 + max_chunks: 32 +pricing: + input: '0.0001' + unit: '0.001' + currency: USD diff --git a/cmd/commandline/init/templates/python/tool.py b/cmd/commandline/init/templates/python/tool.py index 4a40099..c4e907d 100644 --- a/cmd/commandline/init/templates/python/tool.py +++ b/cmd/commandline/init/templates/python/tool.py @@ -4,7 +4,7 @@ from dify_plugin import Tool from dify_plugin.entities.tool import ToolInvokeMessage -class {{plugin_name}}Tool(Tool): +class {{ .PluginName | SnakeToCamel }}Tool(Tool): def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]: yield self.create_json_message({ "result": "Hello, world!" diff --git a/cmd/commandline/init/templates/python/tool.yaml b/cmd/commandline/init/templates/python/tool.yaml index 06d6a3e..abb8b56 100644 --- a/cmd/commandline/init/templates/python/tool.yaml +++ b/cmd/commandline/init/templates/python/tool.yaml @@ -1,16 +1,16 @@ identity: - name: {{plugin_name}} - author: {{author}} + name: {{ .PluginName }} + author: {{ .Author }} label: - en_US: {{plugin_name}} - zh_Hans: {{plugin_name}} - pt_BR: {{plugin_name}} + en_US: {{ .PluginName }} + zh_Hans: {{ .PluginName }} + pt_BR: {{ .PluginName }} description: human: - en_US: {{plugin_description}} - zh_Hans: {{plugin_description}} - pt_BR: {{plugin_description}} - llm: {{plugin_description}} + en_US: {{ .PluginDescription }} + zh_Hans: {{ .PluginDescription }} + pt_BR: {{ .PluginDescription }} + llm: {{ .PluginDescription }} parameters: - name: query type: string @@ -20,11 +20,11 @@ parameters: zh_Hans: 查询语句 pt_BR: Query string human_description: - en_US: {{plugin_description}} - zh_Hans: {{plugin_description}} - pt_BR: {{plugin_description}} - llm_description: {{plugin_description}} + en_US: {{ .PluginDescription }} + zh_Hans: {{ .PluginDescription }} + pt_BR: {{ .PluginDescription }} + llm_description: {{ .PluginDescription }} form: llm extra: python: - source: tools/{{plugin_name}}.py + source: tools/{{ .PluginName }}.py diff --git a/cmd/commandline/init/templates/python/tool_provider.py b/cmd/commandline/init/templates/python/tool_provider.py index c9c7df8..7e06f3f 100644 --- a/cmd/commandline/init/templates/python/tool_provider.py +++ b/cmd/commandline/init/templates/python/tool_provider.py @@ -4,7 +4,7 @@ from dify_plugin.errors.tool import ToolProviderCredentialValidationError -class {{plugin_name}}Provider(ToolProvider): +class {{ .PluginName | SnakeToCamel }}Provider(ToolProvider): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: """ diff --git a/cmd/commandline/init/templates/python/tool_provider.yaml b/cmd/commandline/init/templates/python/tool_provider.yaml index 91e5680..898e78d 100644 --- a/cmd/commandline/init/templates/python/tool_provider.yaml +++ b/cmd/commandline/init/templates/python/tool_provider.yaml @@ -1,17 +1,17 @@ identity: - author: {{author}} - name: {{plugin_name}} + author: {{ .Author }} + name: {{ .PluginName }} label: - en_US: {{plugin_name}} - zh_Hans: {{plugin_name}} - pt_BR: {{plugin_name}} + en_US: {{ .PluginName }} + zh_Hans: {{ .PluginName }} + pt_BR: {{ .PluginName }} description: - en_US: {{plugin_description}} - zh_Hans: {{plugin_description}} - pt_BR: {{plugin_description}} + en_US: {{ .PluginDescription }} + zh_Hans: {{ .PluginDescription }} + pt_BR: {{ .PluginDescription }} icon: icon.svg tools: - - tools/{{plugin_name}}.yaml + - tools/{{ .PluginName }}.yaml extra: python: - source: provider/{{plugin_name}}.py + source: provider/{{ .PluginName }}.py diff --git a/cmd/commandline/init/templates/python/tts.py b/cmd/commandline/init/templates/python/tts.py new file mode 100644 index 0000000..e69de29 diff --git a/cmd/commandline/init/templates/python/tts.yaml b/cmd/commandline/init/templates/python/tts.yaml new file mode 100644 index 0000000..83969fb --- /dev/null +++ b/cmd/commandline/init/templates/python/tts.yaml @@ -0,0 +1,31 @@ +model: tts-1 +model_type: tts +model_properties: + default_voice: 'alloy' + voices: + - mode: 'alloy' + name: 'Alloy' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'echo' + name: 'Echo' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'fable' + name: 'Fable' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'onyx' + name: 'Onyx' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'nova' + name: 'Nova' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'shimmer' + name: 'Shimmer' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 +pricing: + input: '0.015' + output: '0' + unit: '0.001' + currency: USD diff --git a/internal/utils/parser/camel.go b/internal/utils/parser/camel.go new file mode 100644 index 0000000..0313340 --- /dev/null +++ b/internal/utils/parser/camel.go @@ -0,0 +1,14 @@ +package parser + +import "strings" + +func SnakeToCamel(s string) string { + s = strings.ReplaceAll(s, "-", "_") + s = strings.ReplaceAll(s, " ", "_") + + words := strings.Split(s, "_") + for i, word := range words { + words[i] = strings.ToUpper(word[:1]) + word[1:] + } + return strings.Join(words, "") +}