-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(backend): refactor how inference providers are added and configured #475
Changes from all commits
a11b8ef
db9b895
fd379bc
fd8fc06
cef6c64
1575c13
66e3f4e
bec7836
def5e9e
ed43cc8
2e9a0d1
9e23c13
d3dd244
7b7d63b
fd36122
ed85d6d
fd5c6d3
6cab720
53e919e
a19599f
3824a7d
69168ed
edd48ce
2ee9a2c
a370f20
90f47e9
e1b561a
8cb8b0f
da5f115
26c1ca1
bdccdf8
e20311d
6339de2
0c0f457
4ea3145
5fd2727
73ce604
a8c2cfd
3aa84d9
7e41f48
7c0c898
2b03f13
f76ef90
f6d9a08
fe391ad
73b4de4
5746c2a
c9b7a6e
c5209ac
53c3361
96f1983
924ba5e
c9d757b
1104d5f
cf55869
f04b1b2
610c081
723d126
801f45a
a2c74d2
f778e7c
7ba4058
329ec43
59f8f88
0461bf1
95f2ec9
59e2885
74d8acf
749d8ea
048886e
25b39c9
770adc6
8d2600f
6f85c65
a16920c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,11 @@ | |
"[vue]": { | ||
"editor.defaultFormatter": "esbenp.prettier-vscode" | ||
}, | ||
"python.testing.pytestArgs": ["tests"], | ||
"python.testing.pytestArgs": ["tests", "backend"], | ||
"python.testing.unittestEnabled": false, | ||
"python.testing.pytestEnabled": true | ||
"python.testing.pytestEnabled": true, | ||
"sonarlint.connectedMode.project": { | ||
"connectionId": "YeagerAI", | ||
"projectKey": "yeagerai_genlayer-simulator" | ||
} | ||
} | ||
Comment on lines
+5
to
12
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extra: pytest and sonar vscode config |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ script_location = migration | |
|
||
# sys.path path, will be prepended to sys.path if present. | ||
# defaults to the current working directory. | ||
prepend_sys_path = . | ||
prepend_sys_path = ./backend/database_handler | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needed for imports |
||
|
||
# timezone to use when rendering the date within the migration file | ||
# as well as the filename. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from backend.domain.types import LLMProvider | ||
from backend.node.create_nodes.providers import get_default_providers | ||
from .models import LLMProviderDBModel | ||
from sqlalchemy.orm import Session | ||
|
||
|
||
class LLMProviderRegistry: | ||
def __init__(self, session: Session): | ||
self.session = session | ||
|
||
def reset_defaults(self): | ||
"""Reset all providers to their default values.""" | ||
self.session.query(LLMProviderDBModel).delete() | ||
|
||
providers = get_default_providers() | ||
for provider in providers: | ||
self.session.add(_to_db_model(provider)) | ||
|
||
self.session.commit() | ||
|
||
def get_all(self) -> list[LLMProvider]: | ||
return [ | ||
_to_domain(provider) | ||
for provider in self.session.query(LLMProviderDBModel).all() | ||
] | ||
|
||
def get_all_dict(self) -> list[dict]: | ||
return [ | ||
_to_domain(provider).__dict__ | ||
for provider in self.session.query(LLMProviderDBModel).all() | ||
] | ||
|
||
def add(self, provider: LLMProvider) -> int: | ||
model = _to_db_model(provider) | ||
self.session.add(model) | ||
self.session.commit() | ||
return model.id | ||
|
||
def update(self, id: int, provider: LLMProvider): | ||
self.session.query(LLMProviderDBModel).filter( | ||
LLMProviderDBModel.id == id | ||
).update( | ||
{ | ||
LLMProviderDBModel.provider: provider.provider, | ||
LLMProviderDBModel.model: provider.model, | ||
LLMProviderDBModel.config: provider.config, | ||
LLMProviderDBModel.plugin: provider.plugin, | ||
LLMProviderDBModel.plugin_config: provider.plugin_config, | ||
} | ||
) | ||
self.session.commit() | ||
|
||
def delete(self, id: int): | ||
self.session.query(LLMProviderDBModel).filter( | ||
LLMProviderDBModel.id == id | ||
).delete() | ||
self.session.commit() | ||
|
||
|
||
def _to_domain(db_model: LLMProvider) -> LLMProvider: | ||
return LLMProvider( | ||
id=db_model.id, | ||
provider=db_model.provider, | ||
model=db_model.model, | ||
config=db_model.config, | ||
plugin=db_model.plugin, | ||
plugin_config=db_model.plugin_config, | ||
) | ||
|
||
|
||
def _to_db_model(domain: LLMProvider) -> LLMProviderDBModel: | ||
return LLMProviderDBModel( | ||
provider=domain.provider, | ||
model=domain.model, | ||
config=domain.config, | ||
plugin=domain.plugin, | ||
plugin_config=domain.plugin_config, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
"""add plugin and plugin_config to validators | ||
|
||
Revision ID: 986d9a6b0dda | ||
Revises: db38e78684a8 | ||
Create Date: 2024-09-10 14:47:10.730407 | ||
|
||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
from alembic import op | ||
from sqlalchemy import column, table | ||
import sqlalchemy as sa | ||
from sqlalchemy.dialects import postgresql | ||
|
||
from backend.node.create_nodes.providers import get_default_provider_for | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "986d9a6b0dda" | ||
down_revision: Union[str, None] = "db38e78684a8" | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
op.add_column("validators", sa.Column("plugin", sa.String(length=255))) | ||
op.add_column( | ||
"validators", | ||
sa.Column("plugin_config", postgresql.JSONB(astext_type=sa.Text())), | ||
) | ||
|
||
# Modify below | ||
|
||
# Create a table object for the validators table | ||
validators = table( | ||
"validators", | ||
column("id", sa.Integer), | ||
column("provider", sa.String), | ||
column("model", sa.String), | ||
column("plugin", sa.String), | ||
column("plugin_config", postgresql.JSONB), | ||
column("config", postgresql.JSONB), | ||
) | ||
|
||
# Fetch existing data | ||
conn = op.get_bind() | ||
results = conn.execute(validators.select()) | ||
|
||
# Process data and perform updates | ||
for validator in results: | ||
id = validator.id | ||
provider = validator.provider | ||
model = validator.model | ||
default_provider = get_default_provider_for(provider=provider, model=model) | ||
conn.execute( | ||
validators.update() | ||
.where(validators.c.id == id) | ||
.values( | ||
plugin=default_provider.plugin, | ||
plugin_config=default_provider.plugin_config, | ||
config=default_provider.config, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that we are overriding existing configs. This is because the schema has changed and it's the simplest way. Handling the logic of migrating current configs is possible, but would be a lot of work and it's probably not a requirement |
||
) | ||
) | ||
# Modify above | ||
|
||
op.alter_column( | ||
"validators", "plugin", existing_type=sa.VARCHAR(length=255), nullable=False | ||
) | ||
op.alter_column( | ||
"validators", | ||
"plugin_config", | ||
existing_type=postgresql.JSONB(astext_type=sa.Text()), | ||
nullable=False, | ||
) | ||
op.alter_column( | ||
"validators", "provider", existing_type=sa.VARCHAR(length=255), nullable=False | ||
) | ||
op.alter_column( | ||
"validators", "model", existing_type=sa.VARCHAR(length=255), nullable=False | ||
) | ||
|
||
|
||
def downgrade() -> None: | ||
op.alter_column( | ||
"validators", "model", existing_type=sa.VARCHAR(length=255), nullable=True | ||
) | ||
op.alter_column( | ||
"validators", "provider", existing_type=sa.VARCHAR(length=255), nullable=True | ||
) | ||
op.drop_column("validators", "plugin_config") | ||
op.drop_column("validators", "plugin") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extra: moved unit tests