-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Novita AI image generation tool, implemented model search, …
…text-to-image and create tile functionalities (#5308) Co-authored-by: crazywoola <[email protected]>
- Loading branch information
1 parent
3828d4c
commit 132f5fb
Showing
10 changed files
with
983 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Any | ||
|
||
from core.tools.errors import ToolProviderCredentialValidationError | ||
from core.tools.provider.builtin.novitaai.tools.novitaai_txt2img import NovitaAiTxt2ImgTool | ||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController | ||
|
||
|
||
class NovitaAIProvider(BuiltinToolProviderController): | ||
def _validate_credentials(self, credentials: dict[str, Any]) -> None: | ||
try: | ||
result = NovitaAiTxt2ImgTool().fork_tool_runtime( | ||
runtime={ | ||
"credentials": credentials, | ||
} | ||
).invoke( | ||
user_id='', | ||
tool_parameters={ | ||
'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors', | ||
'prompt': 'a futuristic city with flying cars', | ||
'negative_prompt': '', | ||
'width': 128, | ||
'height': 128, | ||
'image_num': 1, | ||
'guidance_scale': 7.5, | ||
'seed': -1, | ||
'steps': 1, | ||
}, | ||
) | ||
except Exception as e: | ||
raise ToolProviderCredentialValidationError(str(e)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
identity: | ||
author: Xiao Ley | ||
name: novitaai | ||
label: | ||
en_US: Novita AI | ||
zh_Hans: Novita AI | ||
pt_BR: Novita AI | ||
description: | ||
en_US: Innovative AI for Image Generation | ||
zh_Hans: 用于图像生成的创新人工智能。 | ||
pt_BR: Innovative AI for Image Generation | ||
icon: icon.ico | ||
tags: | ||
- image | ||
- productivity | ||
credentials_for_provider: | ||
api_key: | ||
type: secret-input | ||
required: true | ||
label: | ||
en_US: API Key | ||
zh_Hans: API 密钥 | ||
pt_BR: Chave API | ||
placeholder: | ||
en_US: Please enter your Novita AI API key | ||
zh_Hans: 请输入你的 Novita AI API 密钥 | ||
pt_BR: Por favor, insira sua chave de API do Novita AI | ||
help: | ||
en_US: Get your Novita AI API key from Novita AI | ||
zh_Hans: 从 Novita AI 获取您的 Novita AI API 密钥 | ||
pt_BR: Obtenha sua chave de API do Novita AI na Novita AI | ||
url: https://novita.ai |
51 changes: 51 additions & 0 deletions
51
api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from base64 import b64decode | ||
from copy import deepcopy | ||
from typing import Any, Union | ||
|
||
from novita_client import ( | ||
NovitaClient, | ||
) | ||
|
||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.errors import ToolProviderCredentialValidationError | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
|
||
class NovitaAiCreateTileTool(BuiltinTool): | ||
def _invoke(self, | ||
user_id: str, | ||
tool_parameters: dict[str, Any], | ||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | ||
""" | ||
invoke tools | ||
""" | ||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): | ||
raise ToolProviderCredentialValidationError("Novita AI API Key is required.") | ||
|
||
api_key = self.runtime.credentials.get('api_key') | ||
|
||
client = NovitaClient(api_key=api_key) | ||
param = self._process_parameters(tool_parameters) | ||
client_result = client.create_tile(**param) | ||
|
||
results = [] | ||
results.append( | ||
self.create_blob_message(blob=b64decode(client_result.image_file), | ||
meta={'mime_type': f'image/{client_result.image_type}'}, | ||
save_as=self.VARIABLE_KEY.IMAGE.value) | ||
) | ||
|
||
return results | ||
|
||
def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
process parameters | ||
""" | ||
res_parameters = deepcopy(parameters) | ||
|
||
# delete none and empty | ||
keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] | ||
for k in keys_to_delete: | ||
del res_parameters[k] | ||
|
||
return res_parameters |
80 changes: 80 additions & 0 deletions
80
api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
identity: | ||
name: novitaai_createtile | ||
author: Xiao Ley | ||
label: | ||
en_US: Novita AI Create Tile | ||
zh_Hans: Novita AI 创建平铺图案 | ||
description: | ||
human: | ||
en_US: This feature produces images designed for seamless tiling, ideal for creating continuous patterns in fabrics, wallpapers, and various textures. | ||
zh_Hans: 该功能生成设计用于无缝平铺的图像,非常适合用于制作连续图案的织物、壁纸和各种纹理。 | ||
llm: A tool for create images designed for seamless tiling, ideal for creating continuous patterns in fabrics, wallpapers, and various textures. | ||
parameters: | ||
- name: prompt | ||
type: string | ||
required: true | ||
label: | ||
en_US: prompt | ||
zh_Hans: 提示 | ||
human_description: | ||
en_US: Positive prompt word of the created tile, divided by `,`, Range [1, 512]. Only English input is allowed. | ||
zh_Hans: 生成平铺图案的正向提示,用 `,` 分隔,范围 [1, 512]。仅允许输入英文。 | ||
llm_description: Image prompt of Novita AI, you should describe the image you want to generate as a list of words as possible as detailed, divided by `,`, Range [1, 512]. Only English input is allowed. | ||
form: llm | ||
- name: negative_prompt | ||
type: string | ||
required: false | ||
label: | ||
en_US: negative prompt | ||
zh_Hans: 负向提示 | ||
human_description: | ||
en_US: Negtive prompt word of the created tile, divided by `,`, Range [1, 512]. Only English input is allowed. | ||
zh_Hans: 生成平铺图案的负向提示,用 `,` 分隔,范围 [1, 512]。仅允许输入英文。 | ||
llm_description: Image negative prompt of Novita AI, divided by `,`, Range [1, 512]. Only English input is allowed. | ||
form: llm | ||
- name: width | ||
type: number | ||
default: 256 | ||
min: 128 | ||
max: 1024 | ||
required: true | ||
label: | ||
en_US: width | ||
zh_Hans: 宽 | ||
human_description: | ||
en_US: Image width, Range [128, 1024]. | ||
zh_Hans: 图像宽度,范围 [128, 1024] | ||
form: form | ||
- name: height | ||
type: number | ||
default: 256 | ||
min: 128 | ||
max: 1024 | ||
required: true | ||
label: | ||
en_US: height | ||
zh_Hans: 高 | ||
human_description: | ||
en_US: Image height, Range [128, 1024]. | ||
zh_Hans: 图像高度,范围 [128, 1024] | ||
form: form | ||
- name: response_image_type | ||
type: select | ||
default: jpeg | ||
required: false | ||
label: | ||
en_US: response image type | ||
zh_Hans: 响应图像类型 | ||
human_description: | ||
en_US: Response image type, png or jpeg | ||
zh_Hans: 响应图像类型,png 或 jpeg | ||
form: form | ||
options: | ||
- value: jpeg | ||
label: | ||
en_US: jpeg | ||
zh_Hans: jpeg | ||
- value: png | ||
label: | ||
en_US: png | ||
zh_Hans: png |
137 changes: 137 additions & 0 deletions
137
api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import json | ||
from copy import deepcopy | ||
from typing import Any, Union | ||
|
||
from pandas import DataFrame | ||
from yarl import URL | ||
|
||
from core.helper import ssrf_proxy | ||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.errors import ToolProviderCredentialValidationError | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
|
||
class NovitaAiModelQueryTool(BuiltinTool): | ||
_model_query_endpoint = 'https://api.novita.ai/v3/model' | ||
|
||
def _invoke(self, | ||
user_id: str, | ||
tool_parameters: dict[str, Any], | ||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | ||
""" | ||
invoke tools | ||
""" | ||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): | ||
raise ToolProviderCredentialValidationError("Novita AI API Key is required.") | ||
|
||
api_key = self.runtime.credentials.get('api_key') | ||
headers = { | ||
'Content-Type': 'application/json', | ||
'Authorization': "Bearer " + api_key | ||
} | ||
params = self._process_parameters(tool_parameters) | ||
result_type = params.get('result_type') | ||
del params['result_type'] | ||
|
||
models_data = self._query_models( | ||
models_data=[], | ||
headers=headers, | ||
params=params, | ||
recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True | ||
) | ||
|
||
result_str = '' | ||
if result_type == 'first sd_name': | ||
result_str = models_data[0]['sd_name_in_api'] | ||
elif result_type == 'first name sd_name pair': | ||
result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) | ||
elif result_type == 'sd_name array': | ||
sd_name_array = [model['sd_name_in_api'] for model in models_data] | ||
result_str = json.dumps(sd_name_array) | ||
elif result_type == 'name array': | ||
name_array = [model['name'] for model in models_data] | ||
result_str = json.dumps(name_array) | ||
elif result_type == 'name sd_name pair array': | ||
name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} for model in models_data] | ||
result_str = json.dumps(name_sd_name_pair_array) | ||
elif result_type == 'whole info array': | ||
result_str = json.dumps(models_data) | ||
else: | ||
raise NotImplementedError | ||
|
||
return self.create_text_message(result_str) | ||
|
||
def _query_models(self, models_data: list, headers: dict[str, Any], | ||
params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list: | ||
""" | ||
query models | ||
""" | ||
inside_params = deepcopy(params) | ||
|
||
if pagination_cursor != '': | ||
inside_params['pagination.cursor'] = pagination_cursor | ||
|
||
response = ssrf_proxy.get( | ||
url=str(URL(self._model_query_endpoint)), | ||
headers=headers, | ||
params=params, | ||
timeout=(10, 60) | ||
) | ||
|
||
res_data = response.json() | ||
|
||
models_data.extend(res_data['models']) | ||
|
||
res_data_len = len(res_data['models']) | ||
if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False: | ||
# deduplicate | ||
df = DataFrame.from_dict(models_data) | ||
df_unique = df.drop_duplicates(subset=['id']) | ||
models_data = df_unique.to_dict('records') | ||
return models_data | ||
|
||
return self._query_models( | ||
models_data=models_data, | ||
headers=headers, | ||
params=inside_params, | ||
pagination_cursor=res_data['pagination']['next_cursor'] | ||
) | ||
|
||
def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: | ||
""" | ||
process parameters | ||
""" | ||
process_parameters = deepcopy(parameters) | ||
res_parameters = {} | ||
|
||
# delete none or empty | ||
keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ''] | ||
for k in keys_to_delete: | ||
del process_parameters[k] | ||
|
||
if 'query' in process_parameters and process_parameters.get('query') != 'unspecified': | ||
res_parameters['filter.query'] = process_parameters['query'] | ||
|
||
if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified': | ||
res_parameters['filter.visibility'] = process_parameters['visibility'] | ||
|
||
if 'source' in process_parameters and process_parameters.get('source') != 'unspecified': | ||
res_parameters['filter.source'] = process_parameters['source'] | ||
|
||
if 'type' in process_parameters and process_parameters.get('type') != 'unspecified': | ||
res_parameters['filter.types'] = process_parameters['type'] | ||
|
||
if 'is_sdxl' in process_parameters: | ||
if process_parameters['is_sdxl'] == 'true': | ||
res_parameters['filter.is_sdxl'] = True | ||
elif process_parameters['is_sdxl'] == 'false': | ||
res_parameters['filter.is_sdxl'] = False | ||
|
||
res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name') | ||
|
||
res_parameters['pagination.limit'] = 1 \ | ||
if res_parameters.get('result_type') == 'first sd_name' \ | ||
or res_parameters.get('result_type') == 'first name sd_name pair'\ | ||
else 100 | ||
|
||
return res_parameters |
Oops, something went wrong.