Skip to content

Commit

Permalink
feat: add Novita AI image generation tool, implemented model search, …
Browse files Browse the repository at this point in the history
…text-to-image and create tile functionalities (#5308)

Co-authored-by: crazywoola <[email protected]>
  • Loading branch information
XiaoLey and crazywoola authored Jun 18, 2024
1 parent 3828d4c commit 132f5fb
Show file tree
Hide file tree
Showing 10 changed files with 983 additions and 0 deletions.
Binary file not shown.
30 changes: 30 additions & 0 deletions api/core/tools/provider/builtin/novitaai/novitaai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.novitaai.tools.novitaai_txt2img import NovitaAiTxt2ImgTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class NovitaAIProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
result = NovitaAiTxt2ImgTool().fork_tool_runtime(
runtime={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors',
'prompt': 'a futuristic city with flying cars',
'negative_prompt': '',
'width': 128,
'height': 128,
'image_num': 1,
'guidance_scale': 7.5,
'seed': -1,
'steps': 1,
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
32 changes: 32 additions & 0 deletions api/core/tools/provider/builtin/novitaai/novitaai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
identity:
author: Xiao Ley
name: novitaai
label:
en_US: Novita AI
zh_Hans: Novita AI
pt_BR: Novita AI
description:
en_US: Innovative AI for Image Generation
zh_Hans: 用于图像生成的创新人工智能。
pt_BR: Innovative AI for Image Generation
icon: icon.ico
tags:
- image
- productivity
credentials_for_provider:
api_key:
type: secret-input
required: true
label:
en_US: API Key
zh_Hans: API 密钥
pt_BR: Chave API
placeholder:
en_US: Please enter your Novita AI API key
zh_Hans: 请输入你的 Novita AI API 密钥
pt_BR: Por favor, insira sua chave de API do Novita AI
help:
en_US: Get your Novita AI API key from Novita AI
zh_Hans: 从 Novita AI 获取您的 Novita AI API 密钥
pt_BR: Obtenha sua chave de API do Novita AI na Novita AI
url: https://novita.ai
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from base64 import b64decode
from copy import deepcopy
from typing import Any, Union

from novita_client import (
NovitaClient,
)

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class NovitaAiCreateTileTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
raise ToolProviderCredentialValidationError("Novita AI API Key is required.")

api_key = self.runtime.credentials.get('api_key')

client = NovitaClient(api_key=api_key)
param = self._process_parameters(tool_parameters)
client_result = client.create_tile(**param)

results = []
results.append(
self.create_blob_message(blob=b64decode(client_result.image_file),
meta={'mime_type': f'image/{client_result.image_type}'},
save_as=self.VARIABLE_KEY.IMAGE.value)
)

return results

def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
process parameters
"""
res_parameters = deepcopy(parameters)

# delete none and empty
keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == '']
for k in keys_to_delete:
del res_parameters[k]

return res_parameters
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
identity:
name: novitaai_createtile
author: Xiao Ley
label:
en_US: Novita AI Create Tile
zh_Hans: Novita AI 创建平铺图案
description:
human:
en_US: This feature produces images designed for seamless tiling, ideal for creating continuous patterns in fabrics, wallpapers, and various textures.
zh_Hans: 该功能生成设计用于无缝平铺的图像,非常适合用于制作连续图案的织物、壁纸和各种纹理。
llm: A tool for create images designed for seamless tiling, ideal for creating continuous patterns in fabrics, wallpapers, and various textures.
parameters:
- name: prompt
type: string
required: true
label:
en_US: prompt
zh_Hans: 提示
human_description:
en_US: Positive prompt word of the created tile, divided by `,`, Range [1, 512]. Only English input is allowed.
zh_Hans: 生成平铺图案的正向提示,用 `,` 分隔,范围 [1, 512]。仅允许输入英文。
llm_description: Image prompt of Novita AI, you should describe the image you want to generate as a list of words as possible as detailed, divided by `,`, Range [1, 512]. Only English input is allowed.
form: llm
- name: negative_prompt
type: string
required: false
label:
en_US: negative prompt
zh_Hans: 负向提示
human_description:
en_US: Negtive prompt word of the created tile, divided by `,`, Range [1, 512]. Only English input is allowed.
zh_Hans: 生成平铺图案的负向提示,用 `,` 分隔,范围 [1, 512]。仅允许输入英文。
llm_description: Image negative prompt of Novita AI, divided by `,`, Range [1, 512]. Only English input is allowed.
form: llm
- name: width
type: number
default: 256
min: 128
max: 1024
required: true
label:
en_US: width
zh_Hans:
human_description:
en_US: Image width, Range [128, 1024].
zh_Hans: 图像宽度,范围 [128, 1024]
form: form
- name: height
type: number
default: 256
min: 128
max: 1024
required: true
label:
en_US: height
zh_Hans:
human_description:
en_US: Image height, Range [128, 1024].
zh_Hans: 图像高度,范围 [128, 1024]
form: form
- name: response_image_type
type: select
default: jpeg
required: false
label:
en_US: response image type
zh_Hans: 响应图像类型
human_description:
en_US: Response image type, png or jpeg
zh_Hans: 响应图像类型,png 或 jpeg
form: form
options:
- value: jpeg
label:
en_US: jpeg
zh_Hans: jpeg
- value: png
label:
en_US: png
zh_Hans: png
137 changes: 137 additions & 0 deletions api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import json
from copy import deepcopy
from typing import Any, Union

from pandas import DataFrame
from yarl import URL

from core.helper import ssrf_proxy
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool


class NovitaAiModelQueryTool(BuiltinTool):
_model_query_endpoint = 'https://api.novita.ai/v3/model'

def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
raise ToolProviderCredentialValidationError("Novita AI API Key is required.")

api_key = self.runtime.credentials.get('api_key')
headers = {
'Content-Type': 'application/json',
'Authorization': "Bearer " + api_key
}
params = self._process_parameters(tool_parameters)
result_type = params.get('result_type')
del params['result_type']

models_data = self._query_models(
models_data=[],
headers=headers,
params=params,
recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True
)

result_str = ''
if result_type == 'first sd_name':
result_str = models_data[0]['sd_name_in_api']
elif result_type == 'first name sd_name pair':
result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']})
elif result_type == 'sd_name array':
sd_name_array = [model['sd_name_in_api'] for model in models_data]
result_str = json.dumps(sd_name_array)
elif result_type == 'name array':
name_array = [model['name'] for model in models_data]
result_str = json.dumps(name_array)
elif result_type == 'name sd_name pair array':
name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} for model in models_data]
result_str = json.dumps(name_sd_name_pair_array)
elif result_type == 'whole info array':
result_str = json.dumps(models_data)
else:
raise NotImplementedError

return self.create_text_message(result_str)

def _query_models(self, models_data: list, headers: dict[str, Any],
params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list:
"""
query models
"""
inside_params = deepcopy(params)

if pagination_cursor != '':
inside_params['pagination.cursor'] = pagination_cursor

response = ssrf_proxy.get(
url=str(URL(self._model_query_endpoint)),
headers=headers,
params=params,
timeout=(10, 60)
)

res_data = response.json()

models_data.extend(res_data['models'])

res_data_len = len(res_data['models'])
if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False:
# deduplicate
df = DataFrame.from_dict(models_data)
df_unique = df.drop_duplicates(subset=['id'])
models_data = df_unique.to_dict('records')
return models_data

return self._query_models(
models_data=models_data,
headers=headers,
params=inside_params,
pagination_cursor=res_data['pagination']['next_cursor']
)

def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
process parameters
"""
process_parameters = deepcopy(parameters)
res_parameters = {}

# delete none or empty
keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == '']
for k in keys_to_delete:
del process_parameters[k]

if 'query' in process_parameters and process_parameters.get('query') != 'unspecified':
res_parameters['filter.query'] = process_parameters['query']

if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified':
res_parameters['filter.visibility'] = process_parameters['visibility']

if 'source' in process_parameters and process_parameters.get('source') != 'unspecified':
res_parameters['filter.source'] = process_parameters['source']

if 'type' in process_parameters and process_parameters.get('type') != 'unspecified':
res_parameters['filter.types'] = process_parameters['type']

if 'is_sdxl' in process_parameters:
if process_parameters['is_sdxl'] == 'true':
res_parameters['filter.is_sdxl'] = True
elif process_parameters['is_sdxl'] == 'false':
res_parameters['filter.is_sdxl'] = False

res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name')

res_parameters['pagination.limit'] = 1 \
if res_parameters.get('result_type') == 'first sd_name' \
or res_parameters.get('result_type') == 'first name sd_name pair'\
else 100

return res_parameters
Loading

0 comments on commit 132f5fb

Please sign in to comment.