Skip to content

Commit

Permalink
Merge pull request #3 from gritaro/rc-0.1.3
Browse files Browse the repository at this point in the history
Add support for completion models configuration
  • Loading branch information
gritaro authored Feb 16, 2024
2 parents 1b4ceef + da89b50 commit 96582f6
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 50 deletions.
23 changes: 17 additions & 6 deletions custom_components/gigachain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""The GigaChain integration."""
from __future__ import annotations
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
Expand All @@ -9,6 +8,7 @@
template,
)
from homeassistant.components.conversation import AgentManager, agent

from typing import Literal
from langchain_community.chat_models import GigaChat, ChatYandexGPT, ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage
Expand All @@ -17,7 +17,7 @@
DOMAIN,
CONF_ENGINE,
CONF_TEMPERATURE,
DEFAULT_CONF_TEMPERATURE,
DEFAULT_TEMPERATURE,
CONF_CHAT_MODEL,
DEFAULT_CHAT_MODEL,
CONF_CHAT_MODEL,
Expand All @@ -38,21 +38,32 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Initialize GigaChain."""
temperature = entry.options.get(CONF_TEMPERATURE, DEFAULT_CONF_TEMPERATURE)
temperature = entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
model = entry.options.get(CONF_CHAT_MODEL)
engine = entry.data.get(CONF_ENGINE) or "gigachat"
entry.async_on_unload(entry.add_update_listener(update_listener))
if engine == 'gigachat':
client = GigaChat(temperature=temperature,
model='GigaChat:latest',
model=model,
verbose=True,
credentials=entry.data[CONF_API_KEY],
verify_ssl_certs=False)
elif engine == 'yandexgpt':
client = ChatYandexGPT(temperature=temperature,
if model == "YandexGPT":
model_url = "gpt://" + entry.data[CONF_FOLDER_ID] + "/yandexgpt/latest"
elif model == 'YandexGPT Lite':
model_url = "gpt://" + entry.data[CONF_FOLDER_ID] + "/yandexgpt-lite/latest"
elif model == 'Summary':
model_url = "gpt://" + entry.data[CONF_FOLDER_ID] + "/summarization/latest"
else:
model_url = ""
client = ChatYandexGPT(
model_uri=model_url,
temperature=temperature,
api_key=entry.data[CONF_API_KEY],
folder_id = entry.data[CONF_FOLDER_ID])
else:
client = ChatOpenAI(model="gpt-3.5-turbo",
client = ChatOpenAI(model=model,
temperature=temperature,
openai_api_key=entry.data[CONF_API_KEY])
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
Expand Down
55 changes: 24 additions & 31 deletions custom_components/gigachain/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
CONF_TEMPERATURE,
CONF_ENGINE_OPTIONS,
CONF_PROMPT,
CONF_MAX_TKNS,
DEFAULT_CONF_TEMPERATURE,
DEFAULT_CONF_MAX_TKNS,
CONF_MAX_TOKENS,
DEFAULT_TEMPERATURE,
DEFAULT_MODELS,
DEFAULT_MAX_TOKENS,
DEFAULT_CHAT_MODEL,
DEFAULT_PROMPT,
UNIQUE_ID,
UNIQUE_ID
)

STEP_USER_SCHEMA = vol.Schema(
Expand All @@ -42,7 +43,7 @@
}
)

STEP_GIGACHAT_SCHEMA = vol.Schema(
STEP_API_KEY_SCHEMA = vol.Schema(
{
vol.Required(CONF_API_KEY): str
}
Expand All @@ -53,16 +54,11 @@
vol.Required(CONF_FOLDER_ID): str
}
)
STEP_OPENAI_SCHEMA = vol.Schema(
{
vol.Required(CONF_API_KEY): str
}
)

ENGINE_SCHEMA = {
"gigachat": STEP_GIGACHAT_SCHEMA,
"gigachat": STEP_API_KEY_SCHEMA,
"yandexgpt": STEP_YANDEXGPT_SCHEMA,
"openai": STEP_OPENAI_SCHEMA
"openai": STEP_API_KEY_SCHEMA
}

DEFAULT_OPTIONS = types.MappingProxyType(
Expand Down Expand Up @@ -136,47 +132,44 @@ async def async_step_init(
"""Manage the options."""
if user_input is not None:
return self.async_create_entry(title=self.config_entry.unique_id, data=user_input)
schema = common_config_option_schema(self.config_entry.options)
schema = common_config_option_schema(self.config_entry.unique_id, self.config_entry.options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
)

def common_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
def common_config_option_schema(unique_id: str, options: MappingProxyType[str, Any]) -> dict:
"""Return a schema for GigaChain completion options."""
if not options:
options = DEFAULT_OPTIONS
return {
vol.Optional(
CONF_CHAT_MODEL,
description={
"suggested_value": options.get(CONF_CHAT_MODEL),
}, default="none",
): selector.SelectSelector(
selector.SelectSelectorConfig(options=DEFAULT_MODELS[unique_id]),
),
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options[CONF_PROMPT]},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_CHAT_MODEL,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL,
DEFAULT_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_TEMPERATURE,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_TEMPERATURE,
DEFAULT_CONF_TEMPERATURE)
DEFAULT_TEMPERATURE)
},
default=DEFAULT_CONF_TEMPERATURE,
default=DEFAULT_TEMPERATURE,
): float,
vol.Optional(
CONF_MAX_TKNS,
CONF_MAX_TOKENS,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_MAX_TKNS,
DEFAULT_CONF_MAX_TKNS)
"suggested_value": options.get(CONF_MAX_TOKENS,
DEFAULT_MAX_TOKENS)
},
default=DEFAULT_CONF_MAX_TKNS,
default=DEFAULT_MAX_TOKENS,
): int,
}
42 changes: 32 additions & 10 deletions custom_components/gigachain/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,38 @@
from homeassistant.helpers import selector

DOMAIN = "gigachain"
CONF_ENGINE = "engine"
UNIQUE_ID = {"gigachat": "GigaChat", "yandexgpt": "YandexGPT", "openai": "OpenAI"}

ID_GIGACHAT = "gigachat"
ID_YANDEX_GPT = "yandexgpt"
ID_OPENAI = "openai"
UNIQUE_ID_GIGACHAT = "GigaChat"
UNIQUE_ID_YANDEX_GPT = "YandexGPT"
UNIQUE_ID_OPENAI = "OpenAI"

UNIQUE_ID = {
ID_GIGACHAT: UNIQUE_ID_GIGACHAT,
ID_YANDEX_GPT: UNIQUE_ID_YANDEX_GPT,
ID_OPENAI: UNIQUE_ID_OPENAI
}

CONF_ENGINE_OPTIONS = [
selector.SelectOptionDict(value="gigachat", label="GigaChat"),
selector.SelectOptionDict(value="yandexgpt", label="YandexGPT"),
selector.SelectOptionDict(value="openai", label="OpenAI"),
selector.SelectOptionDict(value=ID_GIGACHAT, label=UNIQUE_ID_GIGACHAT),
selector.SelectOptionDict(value=ID_YANDEX_GPT, label=UNIQUE_ID_YANDEX_GPT),
selector.SelectOptionDict(value=ID_OPENAI, label=UNIQUE_ID_OPENAI),
]
DEFAULT_MODELS_GIGACHAT = [
"GigaChat", "GigaChat:latest", "GigaChat-Plus", "GigaChat-Pro"
]
DEFAULT_MODELS_YANDEX_GPT = ["YandexGPT", "YandexGPT Lite", "Summary"]
DEFAULT_MODELS_OPENAI = ["gpt-3.5-turbo"]
DEFAULT_MODELS = {
UNIQUE_ID_GIGACHAT: DEFAULT_MODELS_GIGACHAT,
UNIQUE_ID_YANDEX_GPT: DEFAULT_MODELS_YANDEX_GPT,
UNIQUE_ID_OPENAI: DEFAULT_MODELS_OPENAI
}
CONF_API_KEY = "api_key"
CONF_FOLDER_ID = "folder_id"
CONF_ENGINE = "engine"

CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """Ты HAL 9000, компьютер из цикла произведений «Космическая одиссея» Артура Кларка, обладающий способностью к самообучению.
Expand All @@ -33,9 +56,8 @@
"""

CONF_CHAT_MODEL = "model"
#GigaChat-Plus,GigaChat-Pro,GigaChat:latest
DEFAULT_CHAT_MODEL = "GigaChat"
DEFAULT_CHAT_MODEL = ""
CONF_TEMPERATURE = "temperature"
DEFAULT_CONF_TEMPERATURE = "0.1"
CONF_MAX_TKNS = "max_tokens"
DEFAULT_CONF_MAX_TKNS = "250"
DEFAULT_TEMPERATURE = 0.1
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 250
4 changes: 1 addition & 3 deletions custom_components/gigachain/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
"iot_class": "cloud_polling",
"issue_tracker": "https://github.com/gritaro/gigachain/issues",
"requirements": [
"gigachat==0.1.16",
"langchain==0.1.7",
"gigachain-community==0.0.16",
"gigachain==0.1.4",
"yandexcloud==0.259.0"
],
"version": "0.1.2"
Expand Down
17 changes: 17 additions & 0 deletions test-model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Ad-hoc manual smoke test: interactive chat REPL against an Anyscale-hosted
Llama 2 model via LangChain. Run directly; type messages at the prompt
(Ctrl-C to exit)."""
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.chat_models import ChatAnyscale

# Fixed defect: original said "eta-llama/..." — the Anyscale model id is
# "meta-llama/Llama-2-70b-chat-hf".
# NOTE(review): "<key>" is a placeholder; supply a real key (prefer the
# ANYSCALE_API_KEY env var so credentials are never committed).
chat = ChatAnyscale(model="meta-llama/Llama-2-70b-chat-hf", anyscale_api_key="<key>")

# Running conversation history; the system message sets the assistant persona
# and every user/assistant turn is appended so the model keeps context.
messages = [
    SystemMessage(
        content="You are a helpful AI that shares everything you know."
    )
]

while True:  # idiomatic Python: no parentheses around the condition
    user_input = input("User: ")
    messages.append(HumanMessage(content=user_input))
    res = chat(messages)
    messages.append(res)  # keep the assistant reply in history for context
    print("Bot: ", res.content)

0 comments on commit 96582f6

Please sign in to comment.