Skip to content

Commit

Permalink
Update stable_diffusion.py (#7536)
Browse files Browse the repository at this point in the history
  • Loading branch information
flingjie authored Aug 23, 2024
1 parent e42848f commit 70d6ab0
Showing 1 changed file with 39 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"seed_resize_from_w": -1,

# Samplers
# "sampler_name": "DPM++ 2M",
"sampler_name": "DPM++ 2M",
# "scheduler": "",
# "sampler_index": "Automatic",

Expand Down Expand Up @@ -178,6 +178,23 @@ def get_sd_models(self) -> list[str]:
return [d['model_name'] for d in response.json()]
except Exception as e:
return []

def get_sample_methods(self) -> list[str]:
"""
get sample method
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers')
response = get(url=api_url, timeout=(2, 10))
if response.status_code != 200:
return []
else:
return [d['name'] for d in response.json()]
except Exception as e:
return []

def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
Expand Down Expand Up @@ -339,7 +356,27 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)

except:
pass


sample_methods = self.get_sample_methods()
if len(sample_methods) != 0:
parameters.append(
ToolParameter(name='sampler_name',
label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'),
human_description=I18nObject(
en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=sample_methods[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in sample_methods])
)
return parameters

0 comments on commit 70d6ab0

Please sign in to comment.