From a11b8ef69a4d2a6c0bc9165c055a2140256fd37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 3 Sep 2024 14:10:04 -0300 Subject: [PATCH 01/75] add jsf for json schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../node/create_nodes/providers_schema.json | 19 +++++++++++++++++++ .../node/create_nodes/test_create_nodes.py | 15 ++++++++++++++- backend/protocol_rpc/requirements.txt | 1 + 3 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 backend/node/create_nodes/providers_schema.json diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json new file mode 100644 index 000000000..637788f74 --- /dev/null +++ b/backend/node/create_nodes/providers_schema.json @@ -0,0 +1,19 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "providers", + "type": "object", + "enum": [ + { + "type": "admin", + "permissions": ["read", "write", "delete"] + }, + { + "type": "user", + "permissions": ["read", "write"] + }, + { + "type": "guest", + "permissions": ["read"] + } + ] +} diff --git a/backend/node/create_nodes/test_create_nodes.py b/backend/node/create_nodes/test_create_nodes.py index c5da645ca..a352bc065 100644 --- a/backend/node/create_nodes/test_create_nodes.py +++ b/backend/node/create_nodes/test_create_nodes.py @@ -1,4 +1,5 @@ -from backend.node.create_nodes.create_nodes import ( +import os +from create_nodes import ( get_default_config_for_providers_and_nodes, ) @@ -10,3 +11,15 @@ def test_get_default_config_for_providers_and_nodes(): assert "providers" in out assert "provider_weights" in out["providers"] assert "node_defaults" in out + + +def test(): + from jsf import JSF + + current_directory = os.path.dirname(os.path.abspath(__file__)) + schema_file = os.path.join(current_directory, "providers_schema.json") + faker = JSF.from_json(schema_file) + + fake_json = faker.generate() + + print(fake_json) diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index 0ebb02b76..1055c5e40 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -19,3 +19,4 @@ eth-account==0.13.1 eth-utils==4.1.1 sentence-transformers==3.0.1 Flask-SQLAlchemy==3.1.1 +jsf==0.11.2 From db9b8951242c1b608653ce44d614a44a380f0296 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 3 Sep 2024 14:10:15 -0300 Subject: [PATCH 02/75] configure pytest vscode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 45455359c..e65f2fff9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,7 @@ "[vue]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, - "python.testing.pytestArgs": ["tests"], + "python.testing.pytestArgs": ["tests", "backend"], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true } From fd379bc66cfd70b26d7ca3e0a25008966422cb05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 3 Sep 2024 14:10:21 -0300 Subject: [PATCH 03/75] configure dockerignore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .dockerignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.dockerignore b/.dockerignore index cc2a0d5f9..4ac40cdf0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -22,3 +22,8 @@ frontend/src/assets/examples frontend/.nyc_output frontend/coverage frontend/vite.config.ts.timestamp-* + +# Python +# Byte-compiled / optimized / DLL files +**/__pycache__/ +**/*.py[cod] From fd8fc0670e89c5a02a84f8dd2d5375518c2d8285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 3 Sep 2024 14:41:14 -0300 Subject: [PATCH 04/75] improve schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../node/create_nodes/providers_schema.json | 91 +++++++++++++++++-- .../node/create_nodes/test_create_nodes.py | 3 +- 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 637788f74..c80854b98 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -1,19 +1,92 @@ { "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "providers", + "title": "Provider", "type": "object", - "enum": [ + "properties": { + "provider": { + "enum": ["heuristai", "openai", "ollama"] + }, + "model": { + "type": "string" + }, + + "config": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { "const": "string" }, + "pattern": { "type": "string" } + }, + "required": ["type", "pattern"], + "additionalProperties": false + }, + { + "type": "number", + "properties": { + "type": { "const": "number" }, + "minimum": { "type": "number" }, + "maximum": { "type": "number" } + }, + "required": ["type", "minimum", "maximum"], + "additionalProperties": false + } + ] + } + }, + "allOf": [ { - "type": "admin", - "permissions": ["read", "write", "delete"] + "if": { + "properties": { + "provider": { "const": "ollama" } + }, + "required": ["provider"] + }, + + "then": { + "properties": { + "model": { + "enum": ["llama3", "mistral", "gemma"] + } + } + } }, { - "type": "user", - "permissions": ["read", "write"] + "if": { + "properties": { + "provider": { "const": "heuristai" } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "model": { + "enum": [ + "mistralai/mixtral-8x7b-instruct", + "meta-llama/llama-2-70b-chat", + "openhermes-2-yi-34b-gptq", + "dolphin-2.9-llama3-8b" + ] + } + } + } }, { - "type": "guest", - "permissions": ["read"] + "if": { + "properties": { + "provider": { "const": "openai" } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "model": { + "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] + } + } + } } - ] + ], + "required": ["provider", "config", "model"], + "additionalProperties": false } diff --git a/backend/node/create_nodes/test_create_nodes.py b/backend/node/create_nodes/test_create_nodes.py index a352bc065..4f2ec585f 100644 --- a/backend/node/create_nodes/test_create_nodes.py +++ b/backend/node/create_nodes/test_create_nodes.py @@ -21,5 +21,6 @@ def test(): faker = JSF.from_json(schema_file) fake_json = faker.generate() + from pprint import pprint - print(fake_json) + pprint(fake_json) From cef6c641b36cce759ee81118c3dfc823f54a615b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 3 Sep 2024 20:30:20 -0300 Subject: [PATCH 05/75] test hypothesis_jsonschema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../node/create_nodes/providers_schema.json | 138 +++++++++--------- .../node/create_nodes/test_create_nodes.py | 36 ++++- backend/protocol_rpc/requirements.txt | 1 + 3 files changed, 95 insertions(+), 80 deletions(-) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index c80854b98..2851b739f 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -2,91 +2,83 @@ "$schema": "https://json-schema.org/draft/2020-12/schema", "title": "Provider", "type": "object", - "properties": { - "provider": { - "enum": ["heuristai", "openai", "ollama"] - }, - "model": { - "type": "string" - }, - "config": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { "const": "string" }, - "pattern": { "type": "string" } - }, - "required": ["type", "pattern"], - "additionalProperties": false - }, - { - "type": "number", - "properties": { - "type": { "const": "number" }, - "minimum": { "type": "number" }, - "maximum": { "type": "number" } - }, - "required": ["type", "minimum", "maximum"], - "additionalProperties": false - } - ] - } + "properties": { + "provider": { "type": "string" }, + "model": { "type": "string" } }, - "allOf": [ + + "oneOf": [ { - "if": { - "properties": { - "provider": { "const": "ollama" } - }, - "required": ["provider"] + "properties": { + "provider": { "const": "ollama" }, + "model": { "const": "llama3" } }, - - "then": { - "properties": { - "model": { - "enum": ["llama3", "mistral", "gemma"] - } - } + "required": ["provider", "model"] + }, + { + "properties": { + "provider": { "const": "ollama" }, + "model": { "const": "mistral" } } }, { - "if": { - "properties": { - "provider": { "const": "heuristai" } - }, - "required": ["provider"] - }, - "then": { - "properties": { - "model": { - "enum": [ - "mistralai/mixtral-8x7b-instruct", - "meta-llama/llama-2-70b-chat", - "openhermes-2-yi-34b-gptq", - "dolphin-2.9-llama3-8b" - ] - } - } + "properties": { + "provider": { "const": "ollama" }, + "model": { "const": "gemma" } } }, { - "if": { - "properties": { - "provider": { "const": "openai" } - }, - "required": ["provider"] - }, - "then": { - "properties": { - "model": { - "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] - } - } + "properties": { + "provider": { "const": "heuristai" }, + "model": { "const": "mistralai/mixtral-8x7b-instruct" } + } + }, + { + "properties": { + "provider": { "const": "heuristai" }, + "model": { "const": "meta-llama/llama-2-70b-chat" } + } + }, + { + "properties": { + "provider": { "const": "heuristai" }, + "model": { "const": "openhermes-2-yi-34b-gptq" } + } + }, + { + "properties": { + "provider": { "const": "heuristai" }, + "model": { "const": "dolphin-2.9-llama3-8b" } + } + }, + { + "properties": { + "provider": { "const": "openai" }, + "model": { "const": "gpt-3.5-turbo" } + } + }, + { + "properties": { + "provider": { "const": "openai" }, + "model": { "const": "gpt-4" } + } + }, + { + "properties": { + "provider": { "const": "openai" }, + "model": { "const": "gpt-4o" } + } + }, + { + "properties": { + "provider": { "const": "openai" }, + "model": { "const": "gpt-4o-mini" } } } ], - "required": ["provider", "config", "model"], + + "required": ["provider", "model"], + "additionalProperties": false } diff --git a/backend/node/create_nodes/test_create_nodes.py b/backend/node/create_nodes/test_create_nodes.py index 4f2ec585f..89805e14b 100644 --- a/backend/node/create_nodes/test_create_nodes.py +++ b/backend/node/create_nodes/test_create_nodes.py @@ -4,7 +4,7 @@ ) -def test_get_default_config_for_providers_and_nodes(): +def old_test_get_default_config_for_providers_and_nodes(): out = get_default_config_for_providers_and_nodes() assert isinstance(out, dict) @@ -13,14 +13,36 @@ def test_get_default_config_for_providers_and_nodes(): assert "node_defaults" in out -def test(): +current_directory = os.path.dirname(os.path.abspath(__file__)) +schema_file = os.path.join(current_directory, "providers_schema.json") + +import json + +from pprint import pprint + +with open(schema_file, "r") as f: + schema = json.loads(f.read()) + + +from hypothesis import given +from hypothesis_jsonschema import from_schema + + +@given(from_schema(schema)) +def test1(value): + pprint(value) + + +def fadstest(): + # TODO: https://github.com/json-schema-faker/json-schema-faker/tree/master/docs is better at generating fake data. Can we run JavaScript in Python? + # TODO: test https://github.com/python-jsonschema/hypothesis-jsonschema + from jsf import JSF + from jsonschema import validate - current_directory = os.path.dirname(os.path.abspath(__file__)) - schema_file = os.path.join(current_directory, "providers_schema.json") faker = JSF.from_json(schema_file) - fake_json = faker.generate() - from pprint import pprint - pprint(fake_json) + + pprint(from_schema(schema).example()) + # validate(instance=fake_json, schema=schema) diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index 1055c5e40..183337fa4 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -20,3 +20,4 @@ eth-utils==4.1.1 sentence-transformers==3.0.1 Flask-SQLAlchemy==3.1.1 jsf==0.11.2 +jsonschema==4.23.0 From 1575c13bef6bcd70bd0853bce30ee1a05cdf52ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 08:19:29 -0300 Subject: [PATCH 06/75] test other libraries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../node/create_nodes/providers_schema.json | 112 ++++++++---------- .../node/create_nodes/test_create_nodes.py | 27 +++-- 2 files changed, 66 insertions(+), 73 deletions(-) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 2851b739f..fd3931d21 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -2,83 +2,67 @@ "$schema": "https://json-schema.org/draft/2020-12/schema", "title": "Provider", "type": "object", - "properties": { - "provider": { "type": "string" }, - "model": { "type": "string" } + "provider": { + "enum": ["heuristai", "openai", "ollama"] + }, + "model": { + "type": "string" + } }, - - "oneOf": [ + "allOf": [ { - "properties": { - "provider": { "const": "ollama" }, - "model": { "const": "llama3" } + "if": { + "properties": { + "provider": { "const": "ollama" } + }, + "required": ["provider"] }, - "required": ["provider", "model"] - }, - { - "properties": { - "provider": { "const": "ollama" }, - "model": { "const": "mistral" } - } - }, - { - "properties": { - "provider": { "const": "ollama" }, - "model": { "const": "gemma" } - } - }, - { - "properties": { - "provider": { "const": "heuristai" }, - "model": { "const": "mistralai/mixtral-8x7b-instruct" } - } - }, - { - "properties": { - "provider": { "const": "heuristai" }, - "model": { "const": "meta-llama/llama-2-70b-chat" } - } - }, - { - "properties": { - "provider": { "const": "heuristai" }, - "model": { "const": "openhermes-2-yi-34b-gptq" } - } - }, - { - "properties": { - "provider": { "const": "heuristai" }, - "model": { "const": "dolphin-2.9-llama3-8b" } - } - }, - { - "properties": { - "provider": { "const": "openai" }, - "model": { "const": "gpt-3.5-turbo" } - } - }, - { - "properties": { - "provider": { "const": "openai" }, - "model": { "const": "gpt-4" } + + "then": { + "properties": { + "model": { + "enum": ["llama3", "mistral", "gemma"] + } + } } }, { - "properties": { - "provider": { "const": "openai" }, - "model": { "const": "gpt-4o" } + "if": { + "properties": { + "provider": { "const": "heuristai" } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "model": { + "enum": [ + "mistralai/mixtral-8x7b-instruct", + "meta-llama/llama-2-70b-chat", + "openhermes-2-yi-34b-gptq", + "dolphin-2.9-llama3-8b" + ] + } + } } }, { - "properties": { - "provider": { "const": "openai" }, - "model": { "const": "gpt-4o-mini" } + "if": { + "properties": { + "provider": { "const": "openai" } + }, + "required": ["provider"] + }, + "then": { + "properties": { + "model": { + "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] + } + } } } ], - "required": ["provider", "model"], - "additionalProperties": false } diff --git a/backend/node/create_nodes/test_create_nodes.py b/backend/node/create_nodes/test_create_nodes.py index 89805e14b..99a56c379 100644 --- a/backend/node/create_nodes/test_create_nodes.py +++ b/backend/node/create_nodes/test_create_nodes.py @@ -24,25 +24,34 @@ def old_test_get_default_config_for_providers_and_nodes(): schema = json.loads(f.read()) -from hypothesis import given +from hypothesis.errors import NonInteractiveExampleWarning from hypothesis_jsonschema import from_schema +from jsonschema import validate -@given(from_schema(schema)) -def test1(value): +# @given(from_schema(schema)) +def test1(): + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", NonInteractiveExampleWarning) + value = from_schema(schema).example() pprint(value) + validate(instance=value, schema=schema) + print() + print("Finished validating") + print() def fadstest(): # TODO: https://github.com/json-schema-faker/json-schema-faker/tree/master/docs is better at generating fake data. Can we run JavaScript in Python? # TODO: test https://github.com/python-jsonschema/hypothesis-jsonschema + pass + # from jsf import JSF - from jsf import JSF - from jsonschema import validate - - faker = JSF.from_json(schema_file) + # faker = JSF.from_json(schema_file) - pprint(fake_json) + # pprint(fake_json) - pprint(from_schema(schema).example()) + # pprint(from_schema(schema).example()) # validate(instance=fake_json, schema=schema) From 66e3f4ee1e2e9042e96d39c972b357e8a976700c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 09:25:53 -0300 Subject: [PATCH 07/75] add default providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../heuristai_dolphin-2.9-llama3-8b.json | 4 ++++ .../heuristai_meta-llamallama-2-70b-chat.json | 4 ++++ ...uristai_mistralaimixtral-8x7b-instruct.json | 4 ++++ .../heuristai_openhermes-2-yi-34b-gptq.json | 4 ++++ .../default_providers/ollama_gemma.json | 4 ++++ .../default_providers/ollama_llama3.json | 4 ++++ .../default_providers/ollama_mistral.json | 4 ++++ .../openai_gpt-3.5-turbo.json | 4 ++++ .../default_providers/openai_gpt-4.json | 4 ++++ .../default_providers/openai_gpt-4o-mini.json | 4 ++++ .../default_providers/openai_gpt-4o.json | 4 ++++ backend/node/create_nodes/test_create_nodes.py | 18 ++++++++++++++++++ 12 files changed, 62 insertions(+) create mode 100644 backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json create mode 100644 backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json create mode 100644 backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json create mode 100644 backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json create mode 100644 backend/node/create_nodes/default_providers/ollama_gemma.json create mode 100644 backend/node/create_nodes/default_providers/ollama_llama3.json create mode 100644 backend/node/create_nodes/default_providers/ollama_mistral.json create mode 100644 backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json create mode 100644 backend/node/create_nodes/default_providers/openai_gpt-4.json create mode 100644 backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json create mode 100644 backend/node/create_nodes/default_providers/openai_gpt-4o.json diff --git a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json new file mode 100644 index 000000000..ecaf04ea9 --- /dev/null +++ b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json @@ -0,0 +1,4 @@ +{ + "provider": "heuristai", + "model": "dolphin-2.9-llama3-8b" +} diff --git a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json new file mode 100644 index 000000000..20b7b3ba8 --- /dev/null +++ b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json @@ -0,0 +1,4 @@ +{ + "provider": "heuristai", + "model": "meta-llama/llama-2-70b-chat" +} diff --git a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json new file mode 100644 index 000000000..89abba0fb --- /dev/null +++ b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json @@ -0,0 +1,4 @@ +{ + "provider": "heuristai", + "model": "mistralai/mixtral-8x7b-instruct" +} diff --git a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json new file mode 100644 index 000000000..4955154ef --- /dev/null +++ b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json @@ -0,0 +1,4 @@ +{ + "provider": "heuristai", + "model": "openhermes-2-yi-34b-gptq" +} diff --git a/backend/node/create_nodes/default_providers/ollama_gemma.json b/backend/node/create_nodes/default_providers/ollama_gemma.json new file mode 100644 index 000000000..7ab0250ff --- /dev/null +++ b/backend/node/create_nodes/default_providers/ollama_gemma.json @@ -0,0 +1,4 @@ +{ + "provider": "ollama", + "model": "gemma" +} diff --git a/backend/node/create_nodes/default_providers/ollama_llama3.json b/backend/node/create_nodes/default_providers/ollama_llama3.json new file mode 100644 index 000000000..770958ce6 --- /dev/null +++ b/backend/node/create_nodes/default_providers/ollama_llama3.json @@ -0,0 +1,4 @@ +{ + "provider": "ollama", + "model": "llama3" +} diff --git a/backend/node/create_nodes/default_providers/ollama_mistral.json b/backend/node/create_nodes/default_providers/ollama_mistral.json new file mode 100644 index 000000000..3a471f422 --- /dev/null +++ b/backend/node/create_nodes/default_providers/ollama_mistral.json @@ -0,0 +1,4 @@ +{ + "provider": "ollama", + "model": "mistral" +} diff --git a/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json b/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json new file mode 100644 index 000000000..753cfd935 --- /dev/null +++ b/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json @@ -0,0 +1,4 @@ +{ + "provider": "openai", + "model": "gpt-3.5-turbo" +} diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4.json b/backend/node/create_nodes/default_providers/openai_gpt-4.json new file mode 100644 index 000000000..5fbd7ad26 --- /dev/null +++ b/backend/node/create_nodes/default_providers/openai_gpt-4.json @@ -0,0 +1,4 @@ +{ + "provider": "openai", + "model": "gpt-4" +} diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json new file mode 100644 index 000000000..2a8a5a2fd --- /dev/null +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json @@ -0,0 +1,4 @@ +{ + "provider": "openai", + "model": "gpt-4o-mini" +} diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o.json b/backend/node/create_nodes/default_providers/openai_gpt-4o.json new file mode 100644 index 000000000..9a58a2d30 --- /dev/null +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o.json @@ -0,0 +1,4 @@ +{ + "provider": "openai", + "model": "gpt-4o" +} diff --git a/backend/node/create_nodes/test_create_nodes.py b/backend/node/create_nodes/test_create_nodes.py index 99a56c379..b0cd5bea7 100644 --- a/backend/node/create_nodes/test_create_nodes.py +++ b/backend/node/create_nodes/test_create_nodes.py @@ -43,6 +43,24 @@ def test1(): print() +def test_default_providers_valid(): + default_providers_folder = os.path.join(current_directory, "default_providers") + + files = [ + os.path.join(default_providers_folder, filename) + for filename in os.listdir(default_providers_folder) + if filename.endswith(".json") + ] + + assert len(files) > 0 + + for file in files: + with open(file, "r") as f: + provider = json.loads(f.read()) + pprint(provider) + validate(instance=provider, schema=schema) + + def fadstest(): # TODO: https://github.com/json-schema-faker/json-schema-faker/tree/master/docs is better at generating fake data. Can we run JavaScript in Python? # TODO: test https://github.com/python-jsonschema/hypothesis-jsonschema From bec78363664e93dab3ec6b9ea798bc90753b8d19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 09:57:41 -0300 Subject: [PATCH 08/75] move tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 34 +++++++++ .../node/create_nodes/test_create_nodes.py | 75 ------------------- tests/unit/__init__.py | 0 tests/unit/test_create_nodes.py | 36 +++++++++ 4 files changed, 70 insertions(+), 75 deletions(-) create mode 100644 backend/node/create_nodes/providers.py delete mode 100644 backend/node/create_nodes/test_create_nodes.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_create_nodes.py diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py new file mode 100644 index 000000000..d51c5c2dc --- /dev/null +++ b/backend/node/create_nodes/providers.py @@ -0,0 +1,34 @@ +import os +from typing import List +from jsonschema import validate + + +# TODO: cast providers into some kind of class. We know that all providers have `provider` and `model` keys +def get_default_providers() -> List[dict]: + current_directory = os.path.dirname(os.path.abspath(__file__)) + schema_file = os.path.join(current_directory, "providers_schema.json") + + import json + + from pprint import pprint + + with open(schema_file, "r") as f: + schema = json.loads(f.read()) + + default_providers_folder = os.path.join(current_directory, "default_providers") + + files = [ + os.path.join(default_providers_folder, filename) + for filename in os.listdir(default_providers_folder) + if filename.endswith(".json") + ] + + providers = [] + for file in files: + with open(file, "r") as f: + providers.append(json.loads(f.read())) + + for provider in providers: + validate(instance=provider, schema=schema) + + return providers diff --git a/backend/node/create_nodes/test_create_nodes.py b/backend/node/create_nodes/test_create_nodes.py deleted file mode 100644 index b0cd5bea7..000000000 --- a/backend/node/create_nodes/test_create_nodes.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -from create_nodes import ( - get_default_config_for_providers_and_nodes, -) - - -def old_test_get_default_config_for_providers_and_nodes(): - out = get_default_config_for_providers_and_nodes() - - assert isinstance(out, dict) - assert "providers" in out - assert "provider_weights" in out["providers"] - assert "node_defaults" in out - - -current_directory = os.path.dirname(os.path.abspath(__file__)) -schema_file = os.path.join(current_directory, "providers_schema.json") - -import json - -from pprint import pprint - -with open(schema_file, "r") as f: - schema = json.loads(f.read()) - - -from hypothesis.errors import NonInteractiveExampleWarning -from hypothesis_jsonschema import from_schema -from jsonschema import validate - - -# @given(from_schema(schema)) -def test1(): - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", NonInteractiveExampleWarning) - value = from_schema(schema).example() - pprint(value) - validate(instance=value, schema=schema) - print() - print("Finished validating") - print() - - -def test_default_providers_valid(): - default_providers_folder = os.path.join(current_directory, "default_providers") - - files = [ - os.path.join(default_providers_folder, filename) - for filename in os.listdir(default_providers_folder) - if filename.endswith(".json") - ] - - assert len(files) > 0 - - for file in files: - with open(file, "r") as f: - provider = json.loads(f.read()) - pprint(provider) - validate(instance=provider, schema=schema) - - -def fadstest(): - # TODO: https://github.com/json-schema-faker/json-schema-faker/tree/master/docs is better at generating fake data. Can we run JavaScript in Python? - # TODO: test https://github.com/python-jsonschema/hypothesis-jsonschema - pass - # from jsf import JSF - - # faker = JSF.from_json(schema_file) - - # pprint(fake_json) - - # pprint(from_schema(schema).example()) - # validate(instance=fake_json, schema=schema) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py new file mode 100644 index 000000000..863a971ab --- /dev/null +++ b/tests/unit/test_create_nodes.py @@ -0,0 +1,36 @@ +import os +from backend.node.create_nodes.providers import get_default_providers + + +def old_test_get_default_config_for_providers_and_nodes(): + out = get_default_config_for_providers_and_nodes() + + assert isinstance(out, dict) + assert "providers" in out + assert "provider_weights" in out["providers"] + assert "node_defaults" in out + + +# from hypothesis.errors import NonInteractiveExampleWarning +# from hypothesis_jsonschema import from_schema +# from jsonschema import validate + + +# @given(from_schema(schema)) +# def test1(): +# import warnings + +# with warnings.catch_warnings(): +# warnings.simplefilter("ignore", NonInteractiveExampleWarning) +# value = from_schema(schema).example() +# pprint(value) +# validate(instance=value, schema=schema) +# print() +# print("Finished validating") +# print() + + +def test_default_providers_valid(): + providers = get_default_providers() + + assert len(providers) > 0 From def5e9e3fa02e42b20d7dc4f76556f6f209ac2fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 10:11:30 -0300 Subject: [PATCH 09/75] refactor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 37 ++++++++++++++++++++------ tests/unit/test_create_nodes.py | 36 ------------------------- tests/unit/test_providers.py | 19 +++++++++++++ 3 files changed, 48 insertions(+), 44 deletions(-) delete mode 100644 tests/unit/test_create_nodes.py create mode 100644 tests/unit/test_providers.py diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index d51c5c2dc..334db131e 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -1,21 +1,30 @@ +import json import os -from typing import List +import warnings +from typing import List, TypeAlias + +from hypothesis.errors import NonInteractiveExampleWarning +from hypothesis_jsonschema import from_schema from jsonschema import validate +current_directory = os.path.dirname(os.path.abspath(__file__)) +schema_file = os.path.join(current_directory, "providers_schema.json") +default_providers_folder = os.path.join(current_directory, "default_providers") -# TODO: cast providers into some kind of class. We know that all providers have `provider` and `model` keys -def get_default_providers() -> List[dict]: - current_directory = os.path.dirname(os.path.abspath(__file__)) - schema_file = os.path.join(current_directory, "providers_schema.json") +Provider: TypeAlias = dict - import json - from pprint import pprint +def get_schema() -> dict: with open(schema_file, "r") as f: schema = json.loads(f.read()) - default_providers_folder = os.path.join(current_directory, "default_providers") + return schema + + +# TODO: cast providers into some kind of class. We know that all providers have `provider` and `model` keys +def get_default_providers() -> List[Provider]: + schema = get_schema() files = [ os.path.join(default_providers_folder, filename) @@ -32,3 +41,15 @@ def get_default_providers() -> List[dict]: validate(instance=provider, schema=schema) return providers + + +def get_random_provider() -> Provider: + schema = get_schema() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", NonInteractiveExampleWarning) + value = from_schema(schema).example() + + validate(instance=value, schema=schema) + + return value diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py deleted file mode 100644 index 863a971ab..000000000 --- a/tests/unit/test_create_nodes.py +++ /dev/null @@ -1,36 +0,0 @@ -import os -from backend.node.create_nodes.providers import get_default_providers - - -def old_test_get_default_config_for_providers_and_nodes(): - out = get_default_config_for_providers_and_nodes() - - assert isinstance(out, dict) - assert "providers" in out - assert "provider_weights" in out["providers"] - assert "node_defaults" in out - - -# from hypothesis.errors import NonInteractiveExampleWarning -# from hypothesis_jsonschema import from_schema -# from jsonschema import validate - - -# @given(from_schema(schema)) -# def test1(): -# import warnings - -# with warnings.catch_warnings(): -# warnings.simplefilter("ignore", NonInteractiveExampleWarning) -# value = from_schema(schema).example() -# pprint(value) -# validate(instance=value, schema=schema) -# print() -# print("Finished validating") -# print() - - -def test_default_providers_valid(): - providers = get_default_providers() - - assert len(providers) > 0 diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py new file mode 100644 index 000000000..26b329631 --- /dev/null +++ b/tests/unit/test_providers.py @@ -0,0 +1,19 @@ +from backend.node.create_nodes.providers import ( + get_default_providers, + get_random_provider, +) + + +def test_default_providers_valid(): + providers = get_default_providers() + + assert len(providers) > 0 + + +# Takes too long to run +# def test_get_random_provider(): +# provider = get_random_provider() + +# assert provider is not None +# assert "provider" in provider +# assert "model" in provider From ed43cc88abc7622d0fec808fa507006cdfb2758c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 10:40:58 -0300 Subject: [PATCH 10/75] improve schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 12 +- .../node/create_nodes/providers_schema.json | 130 +++++++++++++++++- 2 files changed, 131 insertions(+), 11 deletions(-) diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 334db131e..de4b3a5ec 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -35,10 +35,14 @@ def get_default_providers() -> List[Provider]: providers = [] for file in files: with open(file, "r") as f: - providers.append(json.loads(f.read())) - - for provider in providers: - validate(instance=provider, schema=schema) + provider = json.loads(f.read()) + providers.append((provider, file)) + + for provider, file in providers: + try: + validate(instance=provider, schema=schema) + except Exception as e: + raise ValueError(f"Error validating file {file}, provider {provider}: {e}") return providers diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index fd3931d21..bfb3de5e6 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -15,24 +15,141 @@ "if": { "properties": { "provider": { "const": "ollama" } - }, - "required": ["provider"] + } }, "then": { "properties": { "model": { "enum": ["llama3", "mistral", "gemma"] + }, + "config": { + "type": "object", + "properties": { + "mirostat": { + "type": "integer", + "minimum": 0, + "maximum": 2, + "default": 0 + }, + + "mirostat_eta": { + "type": "number", + "minimum": 0, + "maximum": 1, + "multipleOf": 0.01, + "default": 0.1 + }, + "microstat_tau": { + "type": "number", + "minimum": 0, + "maximum": 10, + "multipleOf": 0.1, + "default": 5 + }, + "num_ctx": { + "enum": [512, 1028, 2048, 4096], + "default": 2048, + "$comment": "this needs to be a per model value" + }, + "num_qga": { + "type": "integer", + "minimum": 1, + "maximum": 20, + "default": 8 + }, + "num_gpu": { + "type": "integer", + "minimum": 0, + "maximum": 16, + "default": 0 + }, + "num_thread": { + "type": "integer", + "minimum": 1, + "maximum": 16, + "default": 2 + }, + "repeat_last_n": { + "enum": [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + "default": 64 + }, + "repeat_penalty": { + "type": "number", + "minimum": 1.0, + "maximum": 2.0, + "multipleOf": 0.1, + "default": 1.1 + }, + "temprature": { + "type": "number", + "minimum": 0, + "maximum": 1.5, + "multipleOf": 0.1, + "default": 0.8 + }, + "seed": { + "type": "integer", + "minimum": 0, + "maximum": 1000000, + "default": 0 + }, + "stop": { + "const": "" + }, + "tfs_z": { + "type": "number", + "minimum": 1.0, + "maximum": 2.0, + "multipleOf": 0.1, + "default": 1.0 + }, + "num_predict": { + "enum": [-2, -1, 32, 64, 128, 256, 512], + "default": 128 + }, + "top_k": { + "type": "integer", + "minimum": 2, + "maximum": 100, + "default": 40 + }, + "top_p": { + "type": "number", + "minimum": 0.5, + "maximum": 0.99, + "multipleOf": 0.01, + "default": 0.9 + } + }, + "required": [ + "mirostat", + "mirostat_eta", + "microstat_tau", + "num_ctx", + "num_qga", + "num_gpu", + "num_thread", + "repeat_last_n", + "repeat_penalty", + "temprature", + "seed", + "stop", + "tfs_z", + "num_predict", + "top_k", + "top_p" + ] } - } + }, + "required": ["config"] } }, { "if": { "properties": { "provider": { "const": "heuristai" } - }, - "required": ["provider"] + } }, "then": { "properties": { @@ -51,8 +168,7 @@ "if": { "properties": { "provider": { "const": "openai" } - }, - "required": ["provider"] + } }, "then": { "properties": { From 2e9a0d1014a1fa12eccd908754224af5333ba516 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 10:46:32 -0300 Subject: [PATCH 11/75] fix default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../default_providers/ollama_gemma.json | 20 ++++++++++++++++++- .../default_providers/ollama_llama3.json | 20 ++++++++++++++++++- .../default_providers/ollama_mistral.json | 20 ++++++++++++++++++- .../node/create_nodes/providers_schema.json | 2 +- 4 files changed, 58 insertions(+), 4 deletions(-) diff --git a/backend/node/create_nodes/default_providers/ollama_gemma.json b/backend/node/create_nodes/default_providers/ollama_gemma.json index 7ab0250ff..c6b5eb810 100644 --- a/backend/node/create_nodes/default_providers/ollama_gemma.json +++ b/backend/node/create_nodes/default_providers/ollama_gemma.json @@ -1,4 +1,22 @@ { "provider": "ollama", - "model": "gemma" + "model": "gemma", + "config": { + "mirostat": 0, + "mirostat_eta": 0.1, + "microstat_tau": 5, + "num_ctx": 2048, + "num_qga": 8, + "num_gpu": 0, + "num_thread": 2, + "repeat_last_n": 64, + "repeat_penalty": 1.1, + "temprature": 0.8, + "seed": 0, + "stop": "", + "tfs_z": 1.0, + "num_predict": 128, + "top_k": 40, + "top_p": 0.9 + } } diff --git a/backend/node/create_nodes/default_providers/ollama_llama3.json b/backend/node/create_nodes/default_providers/ollama_llama3.json index 770958ce6..1e59247a0 100644 --- a/backend/node/create_nodes/default_providers/ollama_llama3.json +++ b/backend/node/create_nodes/default_providers/ollama_llama3.json @@ -1,4 +1,22 @@ { "provider": "ollama", - "model": "llama3" + "model": "llama3", + "config": { + "mirostat": 0, + "mirostat_eta": 0.1, + "microstat_tau": 5, + "num_ctx": 2048, + "num_qga": 8, + "num_gpu": 0, + "num_thread": 2, + "repeat_last_n": 64, + "repeat_penalty": 1.1, + "temprature": 0.8, + "seed": 0, + "stop": "", + "tfs_z": 1.0, + "num_predict": 128, + "top_k": 40, + "top_p": 0.9 + } } diff --git a/backend/node/create_nodes/default_providers/ollama_mistral.json b/backend/node/create_nodes/default_providers/ollama_mistral.json index 3a471f422..7232c036d 100644 --- a/backend/node/create_nodes/default_providers/ollama_mistral.json +++ b/backend/node/create_nodes/default_providers/ollama_mistral.json @@ -1,4 +1,22 @@ { "provider": "ollama", - "model": "mistral" + "model": "mistral", + "config": { + "mirostat": 0, + "mirostat_eta": 0.1, + "microstat_tau": 5, + "num_ctx": 2048, + "num_qga": 8, + "num_gpu": 0, + "num_thread": 2, + "repeat_last_n": 64, + "repeat_penalty": 1.1, + "temprature": 0.8, + "seed": 0, + "stop": "", + "tfs_z": 1.0, + "num_predict": 128, + "top_k": 40, + "top_p": 0.9 + } } diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index bfb3de5e6..54c091267 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -180,5 +180,5 @@ } ], "required": ["provider", "model"], - "additionalProperties": false + "additionalProperties": true } From 9e23c1361556ee386e076fdc9ed37de0335e4b15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 10:47:41 -0300 Subject: [PATCH 12/75] fix schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers_schema.json | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 54c091267..85180b13a 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -8,6 +8,9 @@ }, "model": { "type": "string" + }, + "config": { + "type": "object" } }, "allOf": [ @@ -180,5 +183,5 @@ } ], "required": ["provider", "model"], - "additionalProperties": true + "additionalProperties": false } From d3dd24489c1d11d674e80b0e240c21dcaafdfdf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 14:33:08 -0300 Subject: [PATCH 13/75] extend schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../heuristai_dolphin-2.9-llama3-8b.json | 6 ++++- .../heuristai_meta-llamallama-2-70b-chat.json | 6 ++++- ...ristai_mistralaimixtral-8x7b-instruct.json | 6 ++++- .../heuristai_openhermes-2-yi-34b-gptq.json | 6 ++++- .../node/create_nodes/providers_schema.json | 23 ++++++++++++++++++- 5 files changed, 42 insertions(+), 5 deletions(-) diff --git a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json index ecaf04ea9..57e6b536f 100644 --- a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json +++ b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json @@ -1,4 +1,8 @@ { "provider": "heuristai", - "model": "dolphin-2.9-llama3-8b" + "model": "dolphin-2.9-llama3-8b", + "config": { + "temperature": 0.75, + "max_tokens": 500 + } } diff --git a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json index 20b7b3ba8..a405b094e 100644 --- a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json +++ b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json @@ -1,4 +1,8 @@ { "provider": "heuristai", - "model": "meta-llama/llama-2-70b-chat" + "model": "meta-llama/llama-2-70b-chat", + "config": { + "temperature": 0.75, + "max_tokens": 500 + } } diff --git a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json index 89abba0fb..c65df3de8 100644 --- a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json +++ b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json @@ -1,4 +1,8 @@ { "provider": "heuristai", - "model": "mistralai/mixtral-8x7b-instruct" + "model": "mistralai/mixtral-8x7b-instruct", + "config": { + "temperature": 0.75, + "max_tokens": 500 + } } diff --git a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json index 4955154ef..9bf75a184 100644 --- a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json +++ b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json @@ -1,4 +1,8 @@ { "provider": "heuristai", - "model": "openhermes-2-yi-34b-gptq" + "model": "openhermes-2-yi-34b-gptq", + "config": { + "temperature": 0.75, + "max_tokens": 500 + } } diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 85180b13a..9a7c46a7f 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -163,8 +163,29 @@ "openhermes-2-yi-34b-gptq", "dolphin-2.9-llama3-8b" ] + }, + "config": { + "type": "object", + "properties": { + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 1, + "multipleOf": 0.05, + "default": 0.75 + }, + "max_tokens": { + "type": "integer", + "minimum": 100, + "maximum": 2000, + "multipleOf": 10, + "default": 500 + } + }, + "required": ["temperature", "max_tokens"] } - } + }, + "required": ["config"] } }, { From 7b7d63b9e214c4749f72219a5676bde01571ca87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 14:48:12 -0300 Subject: [PATCH 14/75] extend schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../openai_gpt-3.5-turbo.json | 3 ++- .../default_providers/openai_gpt-4.json | 3 ++- .../default_providers/openai_gpt-4o-mini.json | 3 ++- .../default_providers/openai_gpt-4o.json | 3 ++- backend/node/create_nodes/providers.py | 4 +++- .../node/create_nodes/providers_schema.json | 20 +++++++++++++------ 6 files changed, 25 insertions(+), 11 deletions(-) diff --git a/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json b/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json index 753cfd935..d645c7cbb 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json @@ -1,4 +1,5 @@ { "provider": "openai", - "model": "gpt-3.5-turbo" + "model": "gpt-3.5-turbo", + "config": "" } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4.json b/backend/node/create_nodes/default_providers/openai_gpt-4.json index 5fbd7ad26..bb2ed93f1 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4.json @@ -1,4 +1,5 @@ { "provider": "openai", - "model": "gpt-4" + "model": "gpt-4", + "config": "" } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json index 2a8a5a2fd..a5a387cd7 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json @@ -1,4 +1,5 @@ { "provider": "openai", - "model": "gpt-4o-mini" + "model": "gpt-4o-mini", + "config": "" } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o.json b/backend/node/create_nodes/default_providers/openai_gpt-4o.json index 9a58a2d30..d0f89403d 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o.json @@ -1,4 +1,5 @@ { "provider": "openai", - "model": "gpt-4o" + "model": "gpt-4o", + "config": "" } diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index de4b3a5ec..ffdf40e1c 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -5,7 +5,8 @@ from hypothesis.errors import NonInteractiveExampleWarning from hypothesis_jsonschema import from_schema -from jsonschema import validate +from jsonschema import validate, Draft202012Validator +from jsonschema.protocols import Validator current_directory = os.path.dirname(os.path.abspath(__file__)) schema_file = os.path.join(current_directory, "providers_schema.json") @@ -19,6 +20,7 @@ def get_schema() -> dict: with open(schema_file, "r") as f: schema = json.loads(f.read()) + Draft202012Validator.check_schema(schema) return schema diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 9a7c46a7f..8098dcf0a 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -10,7 +10,14 @@ "type": "string" }, "config": { - "type": "object" + "oneOf": [ + { + "type": "object" + }, + { + "const": "" + } + ] } }, "allOf": [ @@ -144,8 +151,7 @@ "top_p" ] } - }, - "required": ["config"] + } } }, { @@ -184,8 +190,7 @@ }, "required": ["temperature", "max_tokens"] } - }, - "required": ["config"] + } } }, { @@ -198,11 +203,14 @@ "properties": { "model": { "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] + }, + "config": { + "const": "" } } } } ], - "required": ["provider", "model"], + "required": ["provider", "model", "config"], "additionalProperties": false } From fd36122a71e80295634933d29ab65cdef45b54aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 16:06:43 -0300 Subject: [PATCH 15/75] add llm provider object MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/llm_providers.py | 41 +++++++++++++++ .../db38e78684a8_add_providers_table.py | 50 +++++++++++++++++++ backend/database_handler/models.py | 19 +++++++ backend/domain/types.py | 12 +++++ backend/node/create_nodes/create_nodes.py | 7 +-- backend/node/create_nodes/providers.py | 42 +++++++++------- backend/protocol_rpc/endpoint_generator.py | 2 +- backend/protocol_rpc/endpoints.py | 10 ++++ backend/protocol_rpc/server.py | 9 +++- backend/protocol_rpc/types.py | 2 +- 10 files changed, 169 insertions(+), 25 deletions(-) create mode 100644 backend/database_handler/llm_providers.py create mode 100644 backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py create mode 100644 backend/domain/types.py diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py new file mode 100644 index 000000000..b62bfca49 --- /dev/null +++ b/backend/database_handler/llm_providers.py @@ -0,0 +1,41 @@ +from backend.domain.types import LLMProvider +from backend.node.create_nodes.providers import get_default_providers +from .models import LLMProviderDBModel +from sqlalchemy.orm import Session + + +class LLMProviderRegistry: + def __init__(self, session: Session): + self.session = session + + def reset_defaults(self): + """Reset all providers to their default values.""" + self.session.query(LLMProviderDBModel).delete() + + providers = get_default_providers() + for provider in providers: + self.session.add(_to_db_model(provider)) + + self.session.commit() + + def get_all(self) -> list[LLMProvider]: + return [ + _to_domain(provider) + for provider in self.session.query(LLMProviderDBModel).all() + ] + + +def _to_domain(db_model: LLMProvider) -> LLMProvider: + return LLMProvider( + provider=db_model.provider, + model=db_model.model, + config=db_model.config, + ) + + +def _to_db_model(domain: LLMProvider) -> LLMProviderDBModel: + return LLMProviderDBModel( + provider=domain.provider, + model=domain.model, + config=domain.config, + ) diff --git a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py new file mode 100644 index 000000000..e70946954 --- /dev/null +++ b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py @@ -0,0 +1,50 @@ +"""add providers table + +Revision ID: db38e78684a8 +Revises: d9ddc7436122 +Create Date: 2024-09-04 15:16:11.586574 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "db38e78684a8" +down_revision: Union[str, None] = "d9ddc7436122" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "llm_provider", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("provider", sa.String(length=255), nullable=False), + sa.Column("model", sa.String(length=255), nullable=False), + sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id", name="llm_provider_pkey"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("llm_provider") + # ### end Alembic commands ### diff --git a/backend/database_handler/models.py b/backend/database_handler/models.py index 87642a4e1..67c1452f6 100644 --- a/backend/database_handler/models.py +++ b/backend/database_handler/models.py @@ -114,3 +114,22 @@ class Validators(Base): created_at: Mapped[Optional[datetime.datetime]] = mapped_column( DateTime(True), server_default=func.current_timestamp(), init=False ) + + +class LLMProviderDBModel(Base): + __tablename__ = "llm_provider" + __table_args__ = (PrimaryKeyConstraint("id", name="llm_provider_pkey"),) + + id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False) + provider: Mapped[str] = mapped_column(String(255)) + model: Mapped[str] = mapped_column(String(255)) + config: Mapped[dict | str] = mapped_column(JSONB) + created_at: Mapped[datetime.datetime] = mapped_column( + DateTime(True), server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(True), + init=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) diff --git a/backend/domain/types.py b/backend/domain/types.py new file mode 100644 index 000000000..e01a83fcb --- /dev/null +++ b/backend/domain/types.py @@ -0,0 +1,12 @@ +# Types from our domain +# Trying to follow [hexagonal architecture](https://en.wikipedia.org/wiki/Hexagonal_architecture_(software)) or layered architecture. +# These types should not depend on any other layer. + +from dataclasses import dataclass + + +@dataclass() +class LLMProvider: + provider: str + model: str + config: dict diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 1991c6d7a..2783c088e 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -55,14 +55,15 @@ def get_random_provider_using_weights( def get_provider_models( defaults: dict, provider: str, get_ollama_url: Callable[[str], str] ) -> list: + get_default_providers if provider == "ollama": ollama_models_result = requests.get(get_ollama_url("tags")).json() - ollama_models = [] + installed_ollama_models = [] for ollama_model in ollama_models_result["models"]: # "llama3:latest" => "llama3" - ollama_models.append(ollama_model["name"].split(":")[0]) - return ollama_models + installed_ollama_models.append(ollama_model["name"].split(":")[0]) + return installed_ollama_models elif provider == "openai": return defaults["openai_models"].split(",") diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index ffdf40e1c..ba3960a04 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -1,22 +1,20 @@ import json import os import warnings -from typing import List, TypeAlias +from typing import List -from hypothesis.errors import NonInteractiveExampleWarning -from hypothesis_jsonschema import from_schema +# from hypothesis.errors import NonInteractiveExampleWarning +# from hypothesis_jsonschema import from_schema from jsonschema import validate, Draft202012Validator -from jsonschema.protocols import Validator + +from backend.domain.types import LLMProvider current_directory = os.path.dirname(os.path.abspath(__file__)) schema_file = os.path.join(current_directory, "providers_schema.json") default_providers_folder = os.path.join(current_directory, "default_providers") -Provider: TypeAlias = dict - def get_schema() -> dict: - with open(schema_file, "r") as f: schema = json.loads(f.read()) @@ -24,8 +22,7 @@ def get_schema() -> dict: return schema -# TODO: cast providers into some kind of class. We know that all providers have `provider` and `model` keys -def get_default_providers() -> List[Provider]: +def get_default_providers() -> List[LLMProvider]: schema = get_schema() files = [ @@ -38,24 +35,31 @@ def get_default_providers() -> List[Provider]: for file in files: with open(file, "r") as f: provider = json.loads(f.read()) - providers.append((provider, file)) - - for provider, file in providers: try: validate(instance=provider, schema=schema) except Exception as e: raise ValueError(f"Error validating file {file}, provider {provider}: {e}") + providers.append(_to_domain(provider)) + return providers -def get_random_provider() -> Provider: - schema = get_schema() +def _to_domain(provider: dict) -> LLMProvider: + return LLMProvider( + provider=provider["provider"], + model=provider["model"], + config=provider["config"], + ) + + +# def get_random_provider() -> LLMProvider: +# schema = get_schema() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", NonInteractiveExampleWarning) - value = from_schema(schema).example() +# with warnings.catch_warnings(): +# warnings.simplefilter("ignore", NonInteractiveExampleWarning) +# value = from_schema(schema).example() - validate(instance=value, schema=schema) +# validate(instance=value, schema=schema) - return value +# return _to_domain(value) diff --git a/backend/protocol_rpc/endpoint_generator.py b/backend/protocol_rpc/endpoint_generator.py index 96ae2d81e..17d0f090b 100644 --- a/backend/protocol_rpc/endpoint_generator.py +++ b/backend/protocol_rpc/endpoint_generator.py @@ -17,7 +17,7 @@ def generate_rpc_endpoint( ) -> Callable: @jsonrpc.method(function.__name__) @wraps(function) - def endpoint(*args, **kwargs): + def endpoint(*args, **kwargs) -> dict[str]: shouldPrintInfoLogs = ( function.__name__ not in config.get_disabled_info_logs_endpoints() ) diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index ea3902876..a948444a6 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -6,6 +6,7 @@ from sqlalchemy import Table from backend.database_handler.db_client import DBClient +from backend.database_handler.llm_providers import LLMProviderRegistry from backend.database_handler.models import Base from backend.protocol_rpc.configuration import GlobalConfiguration from backend.protocol_rpc.message_handler.base import MessageHandler @@ -120,6 +121,11 @@ def get_contract_schema_for_code( return node.get_contract_schema(contract_code) +# TODO: this shouldn't return a `dict`, but I'm getting `TypeError: return type of dict must be a type; got NoneType instead` +def reset_defaults_llm_providers(llm_provider_registry: LLMProviderRegistry) -> dict: + llm_provider_registry.reset_defaults() + + def get_providers_and_models(config: GlobalConfiguration) -> dict: default_config = get_default_config_for_providers_and_nodes() providers = get_providers() @@ -380,6 +386,7 @@ def register_all_rpc_endpoints( accounts_manager: AccountsManager, transactions_processor: TransactionsProcessor, validators_registry: ValidatorsRegistry, + llm_provider_registry: LLMProviderRegistry, config: GlobalConfiguration, ): register_rpc_endpoint = partial(generate_rpc_endpoint, jsonrpc, msg_handler, config) @@ -403,6 +410,9 @@ def register_all_rpc_endpoints( register_rpc_endpoint_for_partial(get_contract_schema_for_code, msg_handler) register_rpc_endpoint_for_partial(get_providers_and_models, config) + register_rpc_endpoint_for_partial( + reset_defaults_llm_providers, llm_provider_registry + ) register_rpc_endpoint_for_partial( create_validator, validators_registry, accounts_manager ) diff --git a/backend/protocol_rpc/server.py b/backend/protocol_rpc/server.py index 007c8fc42..01312c2d9 100644 --- a/backend/protocol_rpc/server.py +++ b/backend/protocol_rpc/server.py @@ -9,6 +9,7 @@ from flask_socketio import SocketIO from flask_cors import CORS from flask_sqlalchemy import SQLAlchemy +from backend.database_handler.llm_providers import LLMProviderRegistry from backend.protocol_rpc.configuration import GlobalConfiguration from backend.protocol_rpc.message_handler.base import MessageHandler from backend.protocol_rpc.endpoints import register_all_rpc_endpoints @@ -39,13 +40,16 @@ def create_app(): sqlalchemy_db.init_app(app) CORS(app, resources={r"/api/*": {"origins": "*"}}, intercept_exceptions=False) - jsonrpc = JSONRPC(app, "/api", enable_web_browsable_api=True) + jsonrpc = JSONRPC( + app, "/api", enable_web_browsable_api=True + ) # check it out at http://localhost:4000/api/browse/#/ socketio = SocketIO(app, cors_allowed_origins="*") msg_handler = MessageHandler(app, socketio) genlayer_db_client = DBClient(database_name_seed) transactions_processor = TransactionsProcessor(sqlalchemy_db.session) accounts_manager = AccountsManager(sqlalchemy_db.session) validators_registry = ValidatorsRegistry(sqlalchemy_db.session) + llm_provider_registry = LLMProviderRegistry(sqlalchemy_db.session) consensus = ConsensusAlgorithm(genlayer_db_client, msg_handler) return ( @@ -58,6 +62,7 @@ def create_app(): transactions_processor, validators_registry, consensus, + llm_provider_registry, ) @@ -72,6 +77,7 @@ def create_app(): transactions_processor, validators_registry, consensus, + llm_provider_registry, ) = create_app() register_all_rpc_endpoints( jsonrpc, @@ -80,6 +86,7 @@ def create_app(): accounts_manager, transactions_processor, validators_registry, + llm_provider_registry, config=GlobalConfiguration(), ) diff --git a/backend/protocol_rpc/types.py b/backend/protocol_rpc/types.py index 1a8b4e4e8..a0400a251 100644 --- a/backend/protocol_rpc/types.py +++ b/backend/protocol_rpc/types.py @@ -15,7 +15,7 @@ class EndpointResult: data: dict = field(default_factory=dict) exception: Exception = None - def to_json(self): + def to_json(self) -> dict[str]: return { "status": self.status.value, "message": self.message, From ed85d6d7d09fd6a8aca58018a42ef6b42e05a3dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 4 Sep 2024 16:43:33 -0300 Subject: [PATCH 16/75] add more endpoints for configuring providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/llm_providers.py | 18 +++++++++ backend/node/create_nodes/providers.py | 5 +++ backend/protocol_rpc/endpoints.py | 48 +++++++++++++++++------ 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index b62bfca49..9d2256a11 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -24,6 +24,24 @@ def get_all(self) -> list[LLMProvider]: for provider in self.session.query(LLMProviderDBModel).all() ] + def add(self, provider: LLMProvider) -> int: + model = _to_db_model(provider) + self.session.add(model) + self.session.commit() + return model.id + + def edit(self, id: int, provider: LLMProvider): + self.session.query(LLMProviderDBModel).filter( + LLMProviderDBModel.id == id + ).update(_to_db_model(provider)) + self.session.commit() + + def delete(self, id: int): + self.session.query(LLMProviderDBModel).filter( + LLMProviderDBModel.id == id + ).delete() + self.session.commit() + def _to_domain(db_model: LLMProvider) -> LLMProvider: return LLMProvider( diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index ba3960a04..dda378338 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -22,6 +22,11 @@ def get_schema() -> dict: return schema +def validate_provider(provider: LLMProvider): + schema = get_schema() + validate(instance=provider.__dict__, schema=schema) + + def get_default_providers() -> List[LLMProvider]: schema = get_schema() diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index a948444a6..255939107 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -8,6 +8,8 @@ from backend.database_handler.db_client import DBClient from backend.database_handler.llm_providers import LLMProviderRegistry from backend.database_handler.models import Base +from backend.domain.types import LLMProvider +from backend.node.create_nodes.providers import validate_provider from backend.protocol_rpc.configuration import GlobalConfiguration from backend.protocol_rpc.message_handler.base import MessageHandler from backend.database_handler.accounts_manager import AccountsManager @@ -121,20 +123,41 @@ def get_contract_schema_for_code( return node.get_contract_schema(contract_code) -# TODO: this shouldn't return a `dict`, but I'm getting `TypeError: return type of dict must be a type; got NoneType instead` +# TODO: these endpoints shouldn't return a `dict`, but I'm getting `TypeError: return type of dict must be a type; got NoneType instead` def reset_defaults_llm_providers(llm_provider_registry: LLMProviderRegistry) -> dict: llm_provider_registry.reset_defaults() -def get_providers_and_models(config: GlobalConfiguration) -> dict: - default_config = get_default_config_for_providers_and_nodes() - providers = get_providers() - providers_and_models = {} - for provider in providers: - providers_and_models[provider] = get_provider_models( - default_config["providers"], provider, config.get_ollama_url - ) - return providers_and_models +def get_providers_and_models(llm_provider_registry: LLMProviderRegistry) -> dict: + return llm_provider_registry.get_all() + + +def add_provider(llm_provider_registry: LLMProviderRegistry, params: dict) -> dict: + provider = LLMProvider( + provider=params["provider"], + model=params["model"], + config=params["config"], + ) + validate_provider(provider) + + return llm_provider_registry.add(provider) + + +def edit_provider( + llm_provider_registry: LLMProviderRegistry, id: int, params: dict +) -> dict: + provider = LLMProvider( + provider=params["provider"], + model=params["model"], + config=params["config"], + ) + validate_provider(provider) + + llm_provider_registry.edit(id, provider) + + +def delete_provider(llm_provider_registry: LLMProviderRegistry, id: int) -> dict: + llm_provider_registry.delete(id) def create_validator( @@ -409,10 +432,13 @@ def register_all_rpc_endpoints( ) register_rpc_endpoint_for_partial(get_contract_schema_for_code, msg_handler) - register_rpc_endpoint_for_partial(get_providers_and_models, config) + register_rpc_endpoint_for_partial(get_providers_and_models, llm_provider_registry) register_rpc_endpoint_for_partial( reset_defaults_llm_providers, llm_provider_registry ) + register_rpc_endpoint_for_partial(add_provider, llm_provider_registry) + register_rpc_endpoint_for_partial(edit_provider, llm_provider_registry) + register_rpc_endpoint_for_partial(delete_provider, llm_provider_registry) register_rpc_endpoint_for_partial( create_validator, validators_registry, accounts_manager ) From fd5c6d3c852e0ebbe6340e411b18d59b98eeceec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 5 Sep 2024 12:02:37 -0300 Subject: [PATCH 17/75] test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .vscode/settings.json | 6 ++- backend/node/create_nodes/create_nodes.py | 51 ++++++++++++------- backend/node/create_nodes/providers.py | 19 +++---- .../node/create_nodes/providers_schema.json | 2 + backend/protocol_rpc/endpoints.py | 4 +- tests/unit/test_providers.py | 43 ++++++++++++++++ 6 files changed, 96 insertions(+), 29 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index e65f2fff9..cdd85ce2b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,5 +4,9 @@ }, "python.testing.pytestArgs": ["tests", "backend"], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "sonarlint.connectedMode.project": { + "connectionId": "YeagerAI", + "projectKey": "yeagerai_genlayer-simulator" + } } diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 2783c088e..8f92cffc0 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -1,16 +1,18 @@ import os import json import re -from typing import Callable +from typing import Callable, List import requests import numpy as np from random import random, choice, uniform from dotenv import load_dotenv +from backend.domain.types import LLMProvider + load_dotenv() -default_provider_key_regex = r"" +default_provider_key_regex = r"^(|)$" def base_node_json(provider: str, model: str) -> dict: @@ -55,8 +57,6 @@ def get_random_provider_using_weights( def get_provider_models( defaults: dict, provider: str, get_ollama_url: Callable[[str], str] ) -> list: - get_default_providers - if provider == "ollama": ollama_models_result = requests.get(get_ollama_url("tags")).json() installed_ollama_models = [] @@ -122,25 +122,40 @@ def num_decimal_places(number: float) -> int: def random_validator_config( - get_ollama_url: Callable[[str], str], providers: list = None -): - providers = providers or [] - - if len(providers) == 0: - providers = get_providers() - default_config = get_default_config_for_providers_and_nodes() - config = get_config_with_specific_providers(default_config, providers) - ollama_models = get_provider_models({}, "ollama", get_ollama_url) - + get_ollama_url: Callable[[str], str], + get_stored_providers: Callable[[], List[LLMProvider]], + provider_names: List[str] = None, +) -> dict: + provider_names = provider_names or [] + + stored_providers = get_stored_providers() + providers_to_use = stored_providers + + if len(provider_names) > 0: + providers_to_use = [ + provider + for provider in stored_providers + if provider.provider in provider_names + ] + + # default_config = get_default_config_for_providers_and_nodes() + # config = get_config_with_specific_providers(default_config, providers) + # ollama_models = get_provider_models({}, "ollama", get_ollama_url) + + # See if they have an the provider's keys. + + # TODO: when should we check which models are available? Maybe when filling up the database? Should we check every time? + # TODO: this methods for checking the providers are decoupled from the actual configuration and schema of the providers. This means that modifications need to be done in two places. + ollama_models = [ + provider.model for provider in provider_names if provider.provider == "ollama" + ] if ( not len(ollama_models) - and os.environ["OPENAIKEY"] == default_provider_key_regex - and os.environ["HEURISTAIAPIKEY"] == default_provider_key_regex + and re.match(default_provider_key_regex, os.environ.get("OPENAIKEY", "")) + and re.match(default_provider_key_regex, os.environ.get("HEURISTAIAPIKEY", "")) ): raise Exception("No providers avaliable.") - # See if they have an OpenAPI key - # heuristic_models_result = requests.get(os.environ['HEURISTAIMODELSURL']).json() # heuristic_models = [] # for entry in heuristic_models_result: diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index dda378338..f64fafcc5 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -3,8 +3,9 @@ import warnings from typing import List -# from hypothesis.errors import NonInteractiveExampleWarning -# from hypothesis_jsonschema import from_schema +from hypothesis import strategies as st +from hypothesis.errors import NonInteractiveExampleWarning +from hypothesis_jsonschema import from_schema from jsonschema import validate, Draft202012Validator from backend.domain.types import LLMProvider @@ -58,13 +59,13 @@ def _to_domain(provider: dict) -> LLMProvider: ) -# def get_random_provider() -> LLMProvider: -# schema = get_schema() +def get_random_provider() -> LLMProvider: + schema = get_schema() -# with warnings.catch_warnings(): -# warnings.simplefilter("ignore", NonInteractiveExampleWarning) -# value = from_schema(schema).example() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", NonInteractiveExampleWarning) + value = from_schema(schema).example() -# validate(instance=value, schema=schema) + validate(instance=value, schema=schema) -# return _to_domain(value) + return _to_domain(value) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 8098dcf0a..987980f79 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -35,6 +35,7 @@ }, "config": { "type": "object", + "additionalProperties": false, "properties": { "mirostat": { "type": "integer", @@ -172,6 +173,7 @@ }, "config": { "type": "object", + "additionalProperties": false, "properties": { "temperature": { "type": "number", diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 255939107..e2e63625d 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -210,7 +210,9 @@ def create_random_validators( for _ in range(count): stake = random.uniform(min_stake, max_stake) validator_address = accounts_manager.create_new_account().address - details = random_validator_config(config.get_ollama_url, providers=providers) + details = random_validator_config( + config.get_ollama_url, provider_names=providers + ) new_validator = validators_registry.create_validator( validator_address, stake, diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 26b329631..6cc2826ee 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -1,6 +1,10 @@ +from hypothesis import HealthCheck, given, settings +from hypothesis.errors import HypothesisDeprecationWarning +from hypothesis_jsonschema import from_schema from backend.node.create_nodes.providers import ( get_default_providers, get_random_provider, + get_schema, ) @@ -17,3 +21,42 @@ def test_default_providers_valid(): # assert provider is not None # assert "provider" in provider # assert "model" in provider + + +@settings(max_examples=1) +@given( + from_schema(get_schema()), +) +def test_random_provider(value): + print() + print(value) + print() + # assert False + assert value is not None + + +return_value = None + + +def custom(): + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", HypothesisDeprecationWarning) + + @settings(max_examples=1, suppress_health_check=(HealthCheck.return_value,)) + @given( + from_schema(get_schema()), + ) + def inner(value): + global return_value + return_value = value + + inner() + return return_value + + +print(custom()) +print(custom()) +print(custom()) +print(custom()) From 6cab720b01e0663880170b04ac63f5751d4ce292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 5 Sep 2024 13:39:38 -0300 Subject: [PATCH 18/75] add random creation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 31 +++++++++++++-- tests/unit/test_providers.py | 55 ++------------------------ 2 files changed, 32 insertions(+), 54 deletions(-) diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index f64fafcc5..4c080cd67 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -3,10 +3,11 @@ import warnings from typing import List -from hypothesis import strategies as st -from hypothesis.errors import NonInteractiveExampleWarning +from hypothesis import HealthCheck, given, settings +from hypothesis.errors import (HypothesisDeprecationWarning, + NonInteractiveExampleWarning) from hypothesis_jsonschema import from_schema -from jsonschema import validate, Draft202012Validator +from jsonschema import Draft202012Validator, validate from backend.domain.types import LLMProvider @@ -69,3 +70,27 @@ def get_random_provider() -> LLMProvider: validate(instance=value, schema=schema) return _to_domain(value) + + +def create_random_providers(amount: int) -> list[LLMProvider]: + import warnings + + return_value = [] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", HypothesisDeprecationWarning) + + @settings( + max_examples=amount, suppress_health_check=(HealthCheck.return_value,) + ) + @given( + from_schema(get_schema()), + ) + def inner(value): + nonlocal return_value + provider = _to_domain(value) + validate_provider(provider) + return_value.append(provider) + + inner() + return return_value diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 6cc2826ee..3ba86dbf0 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -1,10 +1,6 @@ -from hypothesis import HealthCheck, given, settings -from hypothesis.errors import HypothesisDeprecationWarning -from hypothesis_jsonschema import from_schema from backend.node.create_nodes.providers import ( + create_random_providers, get_default_providers, - get_random_provider, - get_schema, ) @@ -14,49 +10,6 @@ def test_default_providers_valid(): assert len(providers) > 0 -# Takes too long to run -# def test_get_random_provider(): -# provider = get_random_provider() - -# assert provider is not None -# assert "provider" in provider -# assert "model" in provider - - -@settings(max_examples=1) -@given( - from_schema(get_schema()), -) -def test_random_provider(value): - print() - print(value) - print() - # assert False - assert value is not None - - -return_value = None - - -def custom(): - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", HypothesisDeprecationWarning) - - @settings(max_examples=1, suppress_health_check=(HealthCheck.return_value,)) - @given( - from_schema(get_schema()), - ) - def inner(value): - global return_value - return_value = value - - inner() - return return_value - - -print(custom()) -print(custom()) -print(custom()) -print(custom()) +def test_create_random_providers(): + # Note: testing this is very slow, so we only test it once + create_random_providers(1) From 53e919edb2dceae5dbf972b9504a767bf83d976a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 5 Sep 2024 14:23:15 -0300 Subject: [PATCH 19/75] use hypothesis for random generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/create_nodes.py | 115 ++-------------------- backend/node/create_nodes/providers.py | 24 ++--- backend/protocol_rpc/endpoints.py | 16 +-- tests/unit/test_providers.py | 2 +- 4 files changed, 27 insertions(+), 130 deletions(-) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 8f92cffc0..ec63ea20b 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -3,12 +3,11 @@ import re from typing import Callable, List import requests -import numpy as np -from random import random, choice, uniform from dotenv import load_dotenv from backend.domain.types import LLMProvider +from backend.node.create_nodes.providers import create_random_providers load_dotenv() @@ -19,41 +18,6 @@ def base_node_json(provider: str, model: str) -> dict: return {"provider": provider, "model": model, "config": {}} -def get_random_provider_using_weights( - defaults: dict[str], get_ollama_url: Callable[[str], str] -) -> str: - # remove providers if no api key - provider_weights: dict[str, float] = defaults["provider_weights"] - - if "openai" in provider_weights and ( - "OPENAIKEY" not in os.environ - or re.match(default_provider_key_regex, os.environ["OPENAIKEY"]) - ): - provider_weights.pop("openai") - if "heuristai" in provider_weights and ( - "HEURISTAIAPIKEY" not in os.environ.get() - or re.match(os.environ["HEURISTAIAPIKEY"]) - ): - provider_weights.pop("heuristai") - if ( - "ollama" in provider_weights - and get_provider_models({}, "ollama", get_ollama_url) == [] - ): - provider_weights.pop("ollama") - - if len(provider_weights) == 0: - raise Exception("No providers avaliable") - - total_weight = sum(provider_weights.values()) - random_num = uniform(0, total_weight) - - cumulative_weight = 0 - for key, weight in provider_weights.items(): - cumulative_weight += weight - if random_num <= cumulative_weight: - return key - - def get_provider_models( defaults: dict, provider: str, get_ollama_url: Callable[[str], str] ) -> list: @@ -123,12 +87,14 @@ def num_decimal_places(number: float) -> int: def random_validator_config( get_ollama_url: Callable[[str], str], - get_stored_providers: Callable[[], List[LLMProvider]], + # get_stored_providers: Callable[[], List[LLMProvider]], provider_names: List[str] = None, -) -> dict: + amount: int = 1, +) -> List[LLMProvider]: provider_names = provider_names or [] - stored_providers = get_stored_providers() + # stored_providers = get_stored_providers() + stored_providers = [] providers_to_use = stored_providers if len(provider_names) > 0: @@ -161,70 +127,9 @@ def random_validator_config( # for entry in heuristic_models_result: # heuristic_models.append(entry['name']) - provider = get_random_provider_using_weights(config["providers"], get_ollama_url) - options = get_options(provider, config) - - if provider == "openai": - openai_model = choice( - get_provider_models(config["providers"], "openai", get_ollama_url) - ) - node_config = base_node_json("openai", openai_model) - - elif provider == "ollama": - node_config = base_node_json("ollama", choice(ollama_models)) - - for option, option_config in options.items(): - # Just pass the string (for "stop") - if isinstance(option_config, str): - node_config["config"][option] = option_config - # Create a random value - elif isinstance(option_config, dict): - if random() > config["providers"]["chance_of_default_value"]: - random_value = None - if isinstance(option_config["step"], str): - random_value = choice(option_config["step"].split(",")) - node_config["config"][option] = int(random_value) - else: - random_value = choice( - np.arange( - option_config["min"], - option_config["max"], - option_config["step"], - ) - ) - if isinstance(random_value, np.int64): - random_value = int(random_value) - if isinstance(random_value, np.float64): - random_value = float(random_value) - node_config["config"][option] = round( - random_value, num_decimal_places(option_config["step"]) - ) - else: - raise Exception("Option is not a dict or str (" + option + ")") + # provider = get_random_provider_using_weights(config["providers"], get_ollama_url) + # options = get_options(provider, config) - elif provider == "heuristai": - heuristic_model = choice( - get_provider_models(config["providers"], "heuristai", get_ollama_url) - ) - node_config = base_node_json("heuristai", heuristic_model) - for option, option_config in options.items(): - if random() > config["providers"]["chance_of_default_value"]: - random_value = choice( - np.arange( - option_config["min"], - option_config["max"], - option_config["step"], - ) - ) - if isinstance(random_value, np.int64): - random_value = int(random_value) - if isinstance(random_value, np.float64): - random_value = float(random_value) - node_config["config"][option] = round( - random_value, num_decimal_places(option_config["step"]) - ) - - else: - raise Exception("Provider " + provider + " is not specified in defaults") + # raise Exception("Provider " + provider + " is not specified in defaults") - return node_config + return create_random_providers(amount) # TODO: filter by provider and availability diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 4c080cd67..dc5cf87b6 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -1,11 +1,9 @@ import json import os -import warnings from typing import List from hypothesis import HealthCheck, given, settings -from hypothesis.errors import (HypothesisDeprecationWarning, - NonInteractiveExampleWarning) +from hypothesis.errors import HypothesisDeprecationWarning from hypothesis_jsonschema import from_schema from jsonschema import Draft202012Validator, validate @@ -60,19 +58,11 @@ def _to_domain(provider: dict) -> LLMProvider: ) -def get_random_provider() -> LLMProvider: - schema = get_schema() - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", NonInteractiveExampleWarning) - value = from_schema(schema).example() - - validate(instance=value, schema=schema) - - return _to_domain(value) - - def create_random_providers(amount: int) -> list[LLMProvider]: + """ + Creates random providers deriving them from the json schema. + Internally uses hypothesis to generate the data, which is hacky since it's meant to be a testing library. + """ import warnings return_value = [] @@ -84,7 +74,9 @@ def create_random_providers(amount: int) -> list[LLMProvider]: max_examples=amount, suppress_health_check=(HealthCheck.return_value,) ) @given( - from_schema(get_schema()), + from_schema( + get_schema(), + ), ) def inner(value): nonlocal return_value diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index e2e63625d..71d5c665e 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -181,13 +181,13 @@ def create_random_validator( stake: int, ) -> dict: validator_address = accounts_manager.create_new_account().address - details = random_validator_config(config.get_ollama_url) + details = random_validator_config(config.get_ollama_url)[0] response = validators_registry.create_validator( validator_address, stake, - details["provider"], - details["model"], - details["config"], + details.provider, + details.model, + details.config, ) return response @@ -212,13 +212,13 @@ def create_random_validators( validator_address = accounts_manager.create_new_account().address details = random_validator_config( config.get_ollama_url, provider_names=providers - ) + )[0] new_validator = validators_registry.create_validator( validator_address, stake, - fixed_provider or details["provider"], - fixed_model or details["model"], - details["config"], + fixed_provider or details.provider, + fixed_model or details.model, + details.config, ) if not "id" in new_validator: raise SystemError("Failed to create Validator") diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 3ba86dbf0..3e1bd0239 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -12,4 +12,4 @@ def test_default_providers_valid(): def test_create_random_providers(): # Note: testing this is very slow, so we only test it once - create_random_providers(1) + print(create_random_providers(1)) From a19599f5f85e33404c20aaf37440a1f70c39d271 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 5 Sep 2024 14:41:00 -0300 Subject: [PATCH 20/75] fix pytest precommit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 60d947a26..876269f34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: hooks: - id: backend-unit-pytest name: backend unit tests with pytest - entry: python3 -m pytest backend + entry: bash -c "source .venv/bin/activate && pytest tests/unit" language: system types: [python] pass_filenames: false From 3824a7d08f3f608140c9865deeba83d9a9d97a8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 5 Sep 2024 14:41:05 -0300 Subject: [PATCH 21/75] fix requirements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/protocol_rpc/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index 183337fa4..254f4f1a3 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -19,5 +19,6 @@ eth-account==0.13.1 eth-utils==4.1.1 sentence-transformers==3.0.1 Flask-SQLAlchemy==3.1.1 -jsf==0.11.2 +jsf==0.11.2 jsonschema==4.23.0 +hypothesis_jsonschema==0.23.1 From 69168eda7fc3b79327be33d6e79efb330d6bb28b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 5 Sep 2024 14:52:00 -0300 Subject: [PATCH 22/75] remove gpt 3.5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../create_nodes/default_providers/openai_gpt-3.5-turbo.json | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json diff --git a/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json b/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json deleted file mode 100644 index d645c7cbb..000000000 --- a/backend/node/create_nodes/default_providers/openai_gpt-3.5-turbo.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "provider": "openai", - "model": "gpt-3.5-turbo", - "config": "" -} From edd48ceb86f9ddbcb20873e63883ea7023a216eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 08:17:14 -0300 Subject: [PATCH 23/75] improve backend cmd MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- docker/Dockerfile.backend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index a4a138cab..91b446a82 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -25,7 +25,7 @@ COPY backend $path/backend FROM base AS debug RUN pip install --no-cache-dir debugpy watchdog USER backend-user -CMD watchmedo auto-restart --recursive --pattern="*.py" --ignore-patterns="*.pyc;*__pycache__*" -- python -m debugpy --listen 0.0.0.0:${RPCDEBUGPORT} -m flask run -h 0.0.0.0 -p ${FLASK_SERVER_PORT} +CMD watchmedo auto-restart --no-restart-on-command-exit --recursive --pattern="*.py" --ignore-patterns="*.pyc;*__pycache__*" -- python -m debugpy --listen 0.0.0.0:${RPCDEBUGPORT} -m flask run -h 0.0.0.0 -p ${FLASK_SERVER_PORT} ###########START NEW IMAGE: PRODUCTION ################### FROM base AS prod From 2ee9a2c068f748ec16761d3437de4c6090a18d5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 08:17:30 -0300 Subject: [PATCH 24/75] remove defaults.json MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/create_nodes.py | 99 +++++--------------- backend/node/create_nodes/defaults.json | 104 ---------------------- backend/protocol_rpc/endpoints.py | 3 - 3 files changed, 20 insertions(+), 186 deletions(-) delete mode 100644 backend/node/create_nodes/defaults.json diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index ec63ea20b..993a5b77f 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -18,71 +18,13 @@ def base_node_json(provider: str, model: str) -> dict: return {"provider": provider, "model": model, "config": {}} -def get_provider_models( - defaults: dict, provider: str, get_ollama_url: Callable[[str], str] -) -> list: - if provider == "ollama": - ollama_models_result = requests.get(get_ollama_url("tags")).json() - installed_ollama_models = [] - for ollama_model in ollama_models_result["models"]: - # "llama3:latest" => "llama3" - installed_ollama_models.append(ollama_model["name"].split(":")[0]) - return installed_ollama_models - - elif provider == "openai": - return defaults["openai_models"].split(",") - - elif provider == "heuristai": - return defaults["heuristai_models"].split(",") - - else: - raise Exception("Provider (" + provider + ") not found") - - -def get_providers() -> list: - return ["openai", "ollama", "heuristai"] - - -def get_default_config_for_providers_and_nodes() -> dict: - cwd = os.path.abspath(os.getcwd()) - nodes_dir = "/backend/node/create_nodes" - file = open(cwd + nodes_dir + "/defaults.json", "r") - config = json.load(file)[0] - file.close() - return config - - -def get_config_with_specific_providers(config, providers: list) -> dict: - if len(providers) > 0: - default_providers_weights = config["providers"]["provider_weights"] - - # Rebuild the dictionary with only the desired keys - config["providers"]["provider_weights"] = { - provider: weight - for provider, weight in default_providers_weights.items() - if provider in providers - } - return config - - -def get_options(provider, contents): - options = None - for node_default in contents["node_defaults"]: - if node_default["provider"] == provider: - options = node_default["options"] - if not options: - raise Exception(provider + " is not specified in node_defaults") - return options - - -def num_decimal_places(number: float) -> int: - fractional_part = number - int(number) - decimal_places = 0 - while fractional_part != 0: - fractional_part *= 10 - fractional_part -= int(fractional_part) - decimal_places += 1 - return decimal_places +def get_available_ollama_models(get_ollama_url: Callable[[str], str]) -> List[str]: + ollama_models_result = requests.get(get_ollama_url("tags")).json() + installed_ollama_models = [] + for ollama_model in ollama_models_result["models"]: + # "llama3:latest" => "llama3" + installed_ollama_models.append(ollama_model["name"].split(":")[0]) + return installed_ollama_models def random_validator_config( @@ -103,23 +45,20 @@ def random_validator_config( for provider in stored_providers if provider.provider in provider_names ] + # TODO: this methods for checking the providers are decoupled from the actual configuration and schema of the providers. This means that modifications need to be done in two places. - # default_config = get_default_config_for_providers_and_nodes() - # config = get_config_with_specific_providers(default_config, providers) - # ollama_models = get_provider_models({}, "ollama", get_ollama_url) + # TODO: when should we check which models are available? Maybe when filling up the database? Should we check every time since the user can download more models? + available_ollama_models = get_available_ollama_models(get_ollama_url) - # See if they have an the provider's keys. + is_openai_available = not re.match( + default_provider_key_regex, os.environ.get("OPENAIKEY", "") + ) + is_heuristai_available = not re.match( + default_provider_key_regex, os.environ.get("HEURISTAIAPIKEY", "") + ) - # TODO: when should we check which models are available? Maybe when filling up the database? Should we check every time? - # TODO: this methods for checking the providers are decoupled from the actual configuration and schema of the providers. This means that modifications need to be done in two places. - ollama_models = [ - provider.model for provider in provider_names if provider.provider == "ollama" - ] - if ( - not len(ollama_models) - and re.match(default_provider_key_regex, os.environ.get("OPENAIKEY", "")) - and re.match(default_provider_key_regex, os.environ.get("HEURISTAIAPIKEY", "")) - ): + # Check for providers' keys. + if not (available_ollama_models or is_openai_available or is_heuristai_available): raise Exception("No providers avaliable.") # heuristic_models_result = requests.get(os.environ['HEURISTAIMODELSURL']).json() @@ -132,4 +71,6 @@ def random_validator_config( # raise Exception("Provider " + provider + " is not specified in defaults") + # provider = create_random_providers(amount) + return create_random_providers(amount) # TODO: filter by provider and availability diff --git a/backend/node/create_nodes/defaults.json b/backend/node/create_nodes/defaults.json deleted file mode 100644 index d7a743117..000000000 --- a/backend/node/create_nodes/defaults.json +++ /dev/null @@ -1,104 +0,0 @@ -[ - { - "providers": { - "provider_weights": { - "ollama": 0.5, - "openai": 0.5, - "heuristai": 0.5 - }, - "openai_models": "gpt-4,gpt-4o,gpt-4o-mini", - "heuristai_models": "mistralai/mixtral-8x7b-instruct,meta-llama/llama-2-70b-chat,openhermes-2-yi-34b-gptq,dolphin-2.9-llama3-8b", - "chance_of_default_value": 0.5 - }, - "node_defaults": [ - { - "provider": "ollama", - "options": { - "mirostat": { "default": 0, "min": 0, "max": 2, "step": 1 }, - "mirostat_eta": { "default": 0.1, "min": 0, "max": 1, "step": 0.01 }, - "mirostat_tau": { "default": 5, "min": 0, "max": 10, "step": 0.1 }, - "num_ctx": { - "default": 2048, - "min": 512, - "max": 4096, - "step": "512,1024,2048,4096", - "comment": "this needs to be a per model value" - }, - "num_gqa": { "default": 8, "min": 1, "max": 20, "step": 1 }, - "num_gpu": { "default": 0, "min": 0, "max": 16, "step": 1 }, - "num_thread": { "default": 2, "min": 1, "max": 16, "step": 1 }, - "repeat_last_n": { - "default": 64, - "min": 8, - "max": 4096, - "step": "8,16,32,64,128,256,512,1024,2048,4096" - }, - "repeat_penalty": { - "default": 1.1, - "min": 1.0, - "max": 2.0, - "step": 0.1 - }, - "temprature": { "default": 0.8, "min": 0.0, "max": 1.5, "step": 0.1 }, - "seed": { "default": 0, "min": 0, "max": 1000000, "step": 1 }, - "stop": "", - "tfs_z": { "default": 1.0, "min": 1.0, "max": 2.0, "step": 0.1 }, - "num_predict": { - "default": 128, - "min": -2, - "max": 512, - "step": "-2,-1,32,64,128,256,512" - }, - "top_k": { "default": 40, "min": 2, "max": 100, "step": 1 }, - "top_p": { "default": 0.9, "min": 0.5, "max": 0.99, "step": 0.01 } - } - }, - { - "provider": "heuristai", - "options": { - "temperature": { - "default": 0.75, - "min": 0.0, - "max": 1.0, - "step": 0.05 - }, - "max_tokens": { "default": 500, "min": 100, "max": 2000, "step": 10 } - } - }, - { - "provider": "openai", - "options": "" - } - ], - "node_custom": [ - { - "provider": "ollama", - "config": [ - { - "model": "llama3", - "options": { - "num_ctx": 2048 - } - }, - { - "model": "mistral", - "options": { - "mirostat": 0, - "mirostat_eta": 0.2, - "num_ctx": 2048, - "temprature": 0.7, - "num_predict": 64 - } - }, - { - "model": "gemma", - "options": { - "num_ctx": 2048, - "temprature": 0.9 - } - } - ] - } - ] - } -] diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 71d5c665e..3ac6af31c 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -16,9 +16,6 @@ from backend.database_handler.validators_registry import ValidatorsRegistry from backend.node.create_nodes.create_nodes import ( - get_default_config_for_providers_and_nodes, - get_providers, - get_provider_models, random_validator_config, ) From a370f20f1e4c29fd4a4e26c9a00acff4a3906159 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 11:03:15 -0300 Subject: [PATCH 25/75] refactor get random validator config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../backend_integration_tests_pr.yml | 2 + backend/domain/types.py | 3 + backend/node/create_nodes/create_nodes.py | 80 ++++++++++--------- backend/node/create_nodes/providers.py | 7 +- tests/unit/test_create_nodes.py | 16 ++++ tests/unit/test_providers.py | 5 -- 6 files changed, 69 insertions(+), 44 deletions(-) create mode 100644 tests/unit/test_create_nodes.py diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 06a420c68..cfd09ed1f 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -42,6 +42,8 @@ jobs: - name: Copy .env file run: cp .env.example .env + # TODO: we should also add a heuristai key to the e2e tests + - name: Set OPENAIKEY in the .env file so it can be loaded from the environment env: OPENAIKEY: ${{ secrets.OPENAIKEY }} diff --git a/backend/domain/types.py b/backend/domain/types.py index e01a83fcb..ccd05b047 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -10,3 +10,6 @@ class LLMProvider: provider: str model: str config: dict + + def __hash__(self): + return hash((self.provider, self.model, frozenset(self.config.items()))) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 993a5b77f..5cb6f1f40 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -1,21 +1,19 @@ import os -import json import re +import secrets from typing import Callable, List import requests +from numpy.random import default_rng from dotenv import load_dotenv from backend.domain.types import LLMProvider -from backend.node.create_nodes.providers import create_random_providers load_dotenv() +rng = default_rng(secrets.randbits(128)) -default_provider_key_regex = r"^(|)$" - - -def base_node_json(provider: str, model: str) -> dict: - return {"provider": provider, "model": model, "config": {}} +empty_provider_key_regex = r"^(|)$" +provider_key_names_suffix = ["_API_KEY", "KEY", "APIKEY"] def get_available_ollama_models(get_ollama_url: Callable[[str], str]) -> List[str]: @@ -28,37 +26,54 @@ def get_available_ollama_models(get_ollama_url: Callable[[str], str]) -> List[st def random_validator_config( - get_ollama_url: Callable[[str], str], - # get_stored_providers: Callable[[], List[LLMProvider]], - provider_names: List[str] = None, + get_available_ollama_models: Callable[[], str], + get_stored_providers: Callable[[], List[LLMProvider]], + provider_names: set[str] = None, amount: int = 1, + environ: dict = os.environ, ) -> List[LLMProvider]: - provider_names = provider_names or [] + providers_to_use = get_stored_providers() - # stored_providers = get_stored_providers() - stored_providers = [] - providers_to_use = stored_providers - - if len(provider_names) > 0: + if provider_names: providers_to_use = [ provider - for provider in stored_providers + for provider in providers_to_use if provider.provider in provider_names ] - # TODO: this methods for checking the providers are decoupled from the actual configuration and schema of the providers. This means that modifications need to be done in two places. + # stored_providers_to_use + + if not providers_to_use: + raise ValueError( + f"Requested providers '{provider_names}' do not match any stored providers. Please review your stored providers." + ) + + # Ollama is the only provider which is not OpenAI compliant, thus it gets its custom logic + # To add more non-OpenAI compliant providers, we'll need to add more custom logic here or refactor the provider's schema to allow general configurations + available_ollama_models = get_available_ollama_models() + + providers_to_use = [ + provider + for provider in providers_to_use + if provider.model in available_ollama_models + ] + + def filter_by_available_key(provider: LLMProvider) -> bool: + if provider.provider == "ollama": + return True + provider_key_names = [ + provider.provider.upper() + suffix for suffix in provider_key_names_suffix + ] + for provider_key_name in provider_key_names: + if not re.match( + empty_provider_key_regex, environ.get(provider_key_name, "") + ): + return True - # TODO: when should we check which models are available? Maybe when filling up the database? Should we check every time since the user can download more models? - available_ollama_models = get_available_ollama_models(get_ollama_url) + return False - is_openai_available = not re.match( - default_provider_key_regex, os.environ.get("OPENAIKEY", "") - ) - is_heuristai_available = not re.match( - default_provider_key_regex, os.environ.get("HEURISTAIAPIKEY", "") - ) + providers_to_use = list(filter(filter_by_available_key, providers_to_use)) - # Check for providers' keys. - if not (available_ollama_models or is_openai_available or is_heuristai_available): + if not providers_to_use: raise Exception("No providers avaliable.") # heuristic_models_result = requests.get(os.environ['HEURISTAIMODELSURL']).json() @@ -66,11 +81,4 @@ def random_validator_config( # for entry in heuristic_models_result: # heuristic_models.append(entry['name']) - # provider = get_random_provider_using_weights(config["providers"], get_ollama_url) - # options = get_options(provider, config) - - # raise Exception("Provider " + provider + " is not specified in defaults") - - # provider = create_random_providers(amount) - - return create_random_providers(amount) # TODO: filter by provider and availability + return list(rng.choice(providers_to_use, amount)) diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index dc5cf87b6..302bc1955 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -2,9 +2,6 @@ import os from typing import List -from hypothesis import HealthCheck, given, settings -from hypothesis.errors import HypothesisDeprecationWarning -from hypothesis_jsonschema import from_schema from jsonschema import Draft202012Validator, validate from backend.domain.types import LLMProvider @@ -60,9 +57,13 @@ def _to_domain(provider: dict) -> LLMProvider: def create_random_providers(amount: int) -> list[LLMProvider]: """ + Not being used at the moment, left here for future reference. Creates random providers deriving them from the json schema. Internally uses hypothesis to generate the data, which is hacky since it's meant to be a testing library. """ + from hypothesis import HealthCheck, given, settings + from hypothesis.errors import HypothesisDeprecationWarning + from hypothesis_jsonschema import from_schema import warnings return_value = [] diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py new file mode 100644 index 000000000..0332c95c0 --- /dev/null +++ b/tests/unit/test_create_nodes.py @@ -0,0 +1,16 @@ +from backend.domain.types import LLMProvider +from backend.node.create_nodes.create_nodes import random_validator_config + + +def test_random_validator_config(): + stored_providers = [LLMProvider(provider="ollama", model="llama3", config={})] + get_stored_providers = lambda: stored_providers + + get_available_ollama_models = lambda: ["llama3"] + + result = random_validator_config( + get_available_ollama_models, + get_stored_providers, + ) + + assert set(result).issubset(set(stored_providers)) diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 3e1bd0239..b2e7ee1e3 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -8,8 +8,3 @@ def test_default_providers_valid(): providers = get_default_providers() assert len(providers) > 0 - - -def test_create_random_providers(): - # Note: testing this is very slow, so we only test it once - print(create_random_providers(1)) From 90f47e9541d8613caad2c1bbf7a305e230b108dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 11:26:24 -0300 Subject: [PATCH 26/75] add tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/create_nodes.py | 2 +- tests/unit/test_create_nodes.py | 43 ++++++++++++++++++++--- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 5cb6f1f40..09874cb7f 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -30,7 +30,7 @@ def random_validator_config( get_stored_providers: Callable[[], List[LLMProvider]], provider_names: set[str] = None, amount: int = 1, - environ: dict = os.environ, + environ: dict[str, str] = os.environ, ) -> List[LLMProvider]: providers_to_use = get_stored_providers() diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index 0332c95c0..2739280d6 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -1,16 +1,51 @@ +import pytest from backend.domain.types import LLMProvider from backend.node.create_nodes.create_nodes import random_validator_config -def test_random_validator_config(): - stored_providers = [LLMProvider(provider="ollama", model="llama3", config={})] +@pytest.mark.parametrize( + "available_ollama_models,stored_providers,provider_names,amount,environ,expected", + [ + pytest.param( + ["llama3"], + [LLMProvider(provider="ollama", model="llama3", config={})], + None, + 1, + {}, + [LLMProvider(provider="ollama", model="llama3", config={})], + id="only ollama", + ), + pytest.param( + ["llama3", "llama3.1"], + [ + LLMProvider(provider="ollama", model="llama3.1", config={}), + LLMProvider(provider="openai", model="gpt-4", config={}), + LLMProvider(provider="openai", model="gpt-4o", config={}), + LLMProvider( + provider="heuristai", model="meta-llama/llama-2-70b-chat", config={} + ), + ], + None, + 1, + {"OPENAIKEY": ""}, + [LLMProvider(provider="ollama", model="llama3.1", config={})], + id="only ollama available", + ), + ], +) +def test_random_validator_config( + available_ollama_models, stored_providers, provider_names, amount, environ, expected +): get_stored_providers = lambda: stored_providers - get_available_ollama_models = lambda: ["llama3"] + get_available_ollama_models = lambda: available_ollama_models result = random_validator_config( get_available_ollama_models, get_stored_providers, + provider_names, + amount, + environ, ) - assert set(result).issubset(set(stored_providers)) + assert expected == result From e1b561af8403a29f4b987e53f46ba95625141266 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 11:39:14 -0300 Subject: [PATCH 27/75] add more tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/create_nodes.py | 2 +- tests/unit/test_create_nodes.py | 66 ++++++++++++++++++++--- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 09874cb7f..48f805ecc 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -54,7 +54,7 @@ def random_validator_config( providers_to_use = [ provider for provider in providers_to_use - if provider.model in available_ollama_models + if provider.provider != "ollama" or provider.model in available_ollama_models ] def filter_by_available_key(provider: LLMProvider) -> bool: diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index 2739280d6..43e598750 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -10,7 +10,7 @@ ["llama3"], [LLMProvider(provider="ollama", model="llama3", config={})], None, - 1, + 10, {}, [LLMProvider(provider="ollama", model="llama3", config={})], id="only ollama", @@ -21,16 +21,68 @@ LLMProvider(provider="ollama", model="llama3.1", config={}), LLMProvider(provider="openai", model="gpt-4", config={}), LLMProvider(provider="openai", model="gpt-4o", config={}), - LLMProvider( - provider="heuristai", model="meta-llama/llama-2-70b-chat", config={} - ), + LLMProvider(provider="heuristai", model="", config={}), ], None, - 1, - {"OPENAIKEY": ""}, + 10, + {"OPENAI_API_KEY": ""}, [LLMProvider(provider="ollama", model="llama3.1", config={})], id="only ollama available", ), + pytest.param( + ["llama3", "llama3.1"], + [ + LLMProvider(provider="openai", model="gpt-4", config={}), + LLMProvider(provider="openai", model="gpt-4o", config={}), + LLMProvider(provider="heuristai", model="", config={}), + ], + None, + 10, + {"OPENAIKEY": "filled"}, + [ + LLMProvider(provider="openai", model="gpt-4", config={}), + LLMProvider(provider="openai", model="gpt-4o", config={}), + ], + id="only openai available", + ), + pytest.param( + ["llama3", "llama3.1"], + [ + LLMProvider(provider="openai", model="gpt-4", config={}), + LLMProvider(provider="openai", model="gpt-4o", config={}), + LLMProvider(provider="heuristai", model="a", config={}), + LLMProvider(provider="heuristai", model="b", config={}), + ], + None, + 10, + {"OPENAI_API_KEY": "", "HEURISTAI_API_KEY": "filled"}, + [ + LLMProvider(provider="heuristai", model="a", config={}), + LLMProvider(provider="heuristai", model="b", config={}), + ], + id="only heuristai", + ), + pytest.param( + ["llama3", "llama3.1"], + [ + LLMProvider(provider="ollama", model="llama3.1", config={}), + LLMProvider(provider="openai", model="gpt-4", config={}), + LLMProvider(provider="openai", model="gpt-4o", config={}), + LLMProvider(provider="heuristai", model="a", config={}), + LLMProvider(provider="heuristai", model="b", config={}), + ], + None, + 10, + {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, + [ + LLMProvider(provider="ollama", model="llama3.1", config={}), + LLMProvider(provider="openai", model="gpt-4", config={}), + LLMProvider(provider="openai", model="gpt-4o", config={}), + LLMProvider(provider="heuristai", model="a", config={}), + LLMProvider(provider="heuristai", model="b", config={}), + ], + id="all available", + ), ], ) def test_random_validator_config( @@ -48,4 +100,4 @@ def test_random_validator_config( environ, ) - assert expected == result + assert set(result).issubset(set(expected)) From 8cb8b0fc16ae3313a70e816bdea1755193ed6f4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 11:48:13 -0300 Subject: [PATCH 28/75] add tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- tests/unit/test_create_nodes.py | 51 +++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index 43e598750..eecdfbf71 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -88,16 +88,55 @@ def test_random_validator_config( available_ollama_models, stored_providers, provider_names, amount, environ, expected ): - get_stored_providers = lambda: stored_providers - - get_available_ollama_models = lambda: available_ollama_models - result = random_validator_config( - get_available_ollama_models, - get_stored_providers, + lambda: available_ollama_models, + lambda: stored_providers, provider_names, amount, environ, ) assert set(result).issubset(set(expected)) + + +@pytest.mark.parametrize( + "available_ollama_models,stored_providers,provider_names,amount,environ,exception", + [ + pytest.param( + [], + [LLMProvider(provider="ollama", model="llama3", config={})], + ["heuristai", "openai"], + 10, + {}, + ValueError, + id="no match", + ), + pytest.param( + [], + [LLMProvider(provider="ollama", model="llama3", config={})], + ["ollama"], + 10, + {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, + Exception, + id="no intersection", + ), + ], +) +def test_random_validator_config_fail( + available_ollama_models, + stored_providers, + provider_names, + amount, + environ, + exception, +): + with pytest.raises(exception): + random_validator_config( + result=random_validator_config( + lambda: available_ollama_models, + lambda: stored_providers, + provider_names, + amount, + environ, + ) + ) From da5f1153a2f295de328b018ba174b0a733055977 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 12:08:48 -0300 Subject: [PATCH 29/75] refactor to use random_validator_config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/domain/types.py | 9 +++++ backend/node/create_nodes/create_nodes.py | 15 +++++--- backend/protocol_rpc/configuration.py | 9 +++++ backend/protocol_rpc/endpoints.py | 42 ++++++++++++++--------- tests/unit/test_create_nodes.py | 38 ++++++++++++++------ tests/unit/test_providers.py | 1 - 6 files changed, 82 insertions(+), 32 deletions(-) diff --git a/backend/domain/types.py b/backend/domain/types.py index ccd05b047..306092b0c 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -13,3 +13,12 @@ class LLMProvider: def __hash__(self): return hash((self.provider, self.model, frozenset(self.config.items()))) + + +@dataclass() +class Validator: + address: str + stake: int + provider: str + model: str + config: dict diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 48f805ecc..91c97efe2 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -28,23 +28,28 @@ def get_available_ollama_models(get_ollama_url: Callable[[str], str]) -> List[st def random_validator_config( get_available_ollama_models: Callable[[], str], get_stored_providers: Callable[[], List[LLMProvider]], - provider_names: set[str] = None, + limit_providers: set[str] = None, + limit_models: set[str] = None, amount: int = 1, environ: dict[str, str] = os.environ, ) -> List[LLMProvider]: providers_to_use = get_stored_providers() - if provider_names: + if limit_providers: providers_to_use = [ provider for provider in providers_to_use - if provider.provider in provider_names + if provider.provider in limit_providers + ] + + if limit_models: + providers_to_use = [ + provider for provider in providers_to_use if provider.model in limit_models ] - # stored_providers_to_use if not providers_to_use: raise ValueError( - f"Requested providers '{provider_names}' do not match any stored providers. Please review your stored providers." + f"Requested providers '{limit_providers}' do not match any stored providers. Please review your stored providers." ) # Ollama is the only provider which is not OpenAI compliant, thus it gets its custom logic diff --git a/backend/protocol_rpc/configuration.py b/backend/protocol_rpc/configuration.py index 2f861b7da..eed7e9cc3 100644 --- a/backend/protocol_rpc/configuration.py +++ b/backend/protocol_rpc/configuration.py @@ -1,6 +1,8 @@ import os import json +import requests + class GlobalConfiguration: @staticmethod @@ -9,3 +11,10 @@ def get_ollama_url(endpoint: str) -> str: def get_disabled_info_logs_endpoints(self) -> list: return json.loads(os.environ.get("DISABLE_INFO_LOGS_ENDPOINTS", "[]")) + + def get_available_ollama_models(self) -> list: + ollama_models_result = requests.get(self.get_ollama_url("tags")).json() + installed_ollama_models = [] + for ollama_model in ollama_models_result["models"]: + installed_ollama_models.append(ollama_model["name"].split(":")[0]) + return installed_ollama_models diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 3ac6af31c..9581626c2 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -174,11 +174,16 @@ def create_validator( def create_random_validator( validators_registry: ValidatorsRegistry, accounts_manager: AccountsManager, + llm_provider_registry: LLMProviderRegistry, config: GlobalConfiguration, stake: int, ) -> dict: validator_address = accounts_manager.create_new_account().address - details = random_validator_config(config.get_ollama_url)[0] + details = random_validator_config( + config.get_available_ollama_models, + llm_provider_registry.get_all, + 1, + )[0] response = validators_registry.create_validator( validator_address, stake, @@ -194,31 +199,36 @@ def create_random_validator( def create_random_validators( validators_registry: ValidatorsRegistry, accounts_manager: AccountsManager, + llm_provider_registry: LLMProviderRegistry, config: GlobalConfiguration, count: int, min_stake: int, max_stake: int, - providers: list = None, - fixed_provider: str = None, - fixed_model: str = None, + limit_providers: list[str] = None, + limit_models: list[str] = None, ) -> dict: - providers = providers or [] + limit_providers = limit_providers or [] + limit_models = limit_models or [] + + details = random_validator_config( + config.get_available_ollama_models, + llm_provider_registry.get_all, + limit_providers=set(limit_providers), + limit_models=set(limit_models), + amount=count, + ) - for _ in range(count): - stake = random.uniform(min_stake, max_stake) + for detail in details: + stake = random.randint(min_stake, max_stake) validator_address = accounts_manager.create_new_account().address - details = random_validator_config( - config.get_ollama_url, provider_names=providers - )[0] - new_validator = validators_registry.create_validator( + + validators_registry.create_validator( validator_address, stake, - fixed_provider or details.provider, - fixed_model or details.model, - details.config, + detail.provider, + detail.model, + detail.config, ) - if not "id" in new_validator: - raise SystemError("Failed to create Validator") response = validators_registry.get_all_validators() return response diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index eecdfbf71..622028475 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -4,12 +4,13 @@ @pytest.mark.parametrize( - "available_ollama_models,stored_providers,provider_names,amount,environ,expected", + "available_ollama_models,stored_providers,limit_providers,limit_models,amount,environ,expected", [ pytest.param( ["llama3"], [LLMProvider(provider="ollama", model="llama3", config={})], None, + None, 10, {}, [LLMProvider(provider="ollama", model="llama3", config={})], @@ -24,6 +25,7 @@ LLMProvider(provider="heuristai", model="", config={}), ], None, + None, 10, {"OPENAI_API_KEY": ""}, [LLMProvider(provider="ollama", model="llama3.1", config={})], @@ -32,10 +34,14 @@ pytest.param( ["llama3", "llama3.1"], [ + LLMProvider(provider="ollama", model="llama3", config={}), LLMProvider(provider="openai", model="gpt-4", config={}), LLMProvider(provider="openai", model="gpt-4o", config={}), LLMProvider(provider="heuristai", model="", config={}), + LLMProvider(provider="heuristai", model="a", config={}), + LLMProvider(provider="heuristai", model="b", config={}), ], + ["openai"], None, 10, {"OPENAIKEY": "filled"}, @@ -43,7 +49,7 @@ LLMProvider(provider="openai", model="gpt-4", config={}), LLMProvider(provider="openai", model="gpt-4o", config={}), ], - id="only openai available", + id="only openai", ), pytest.param( ["llama3", "llama3.1"], @@ -53,12 +59,12 @@ LLMProvider(provider="heuristai", model="a", config={}), LLMProvider(provider="heuristai", model="b", config={}), ], - None, + ["heuristai"], + ["a"], 10, - {"OPENAI_API_KEY": "", "HEURISTAI_API_KEY": "filled"}, + {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, [ LLMProvider(provider="heuristai", model="a", config={}), - LLMProvider(provider="heuristai", model="b", config={}), ], id="only heuristai", ), @@ -72,6 +78,7 @@ LLMProvider(provider="heuristai", model="b", config={}), ], None, + None, 10, {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, [ @@ -86,12 +93,19 @@ ], ) def test_random_validator_config( - available_ollama_models, stored_providers, provider_names, amount, environ, expected + available_ollama_models, + stored_providers, + limit_providers, + limit_models, + amount, + environ, + expected, ): result = random_validator_config( lambda: available_ollama_models, lambda: stored_providers, - provider_names, + limit_providers, + limit_models, amount, environ, ) @@ -100,12 +114,13 @@ def test_random_validator_config( @pytest.mark.parametrize( - "available_ollama_models,stored_providers,provider_names,amount,environ,exception", + "available_ollama_models,stored_providers,limit_providers,limit_models,amount,environ,exception", [ pytest.param( [], [LLMProvider(provider="ollama", model="llama3", config={})], ["heuristai", "openai"], + None, 10, {}, ValueError, @@ -115,6 +130,7 @@ def test_random_validator_config( [], [LLMProvider(provider="ollama", model="llama3", config={})], ["ollama"], + None, 10, {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, Exception, @@ -125,7 +141,8 @@ def test_random_validator_config( def test_random_validator_config_fail( available_ollama_models, stored_providers, - provider_names, + limit_providers, + limit_models, amount, environ, exception, @@ -135,7 +152,8 @@ def test_random_validator_config_fail( result=random_validator_config( lambda: available_ollama_models, lambda: stored_providers, - provider_names, + limit_providers, + limit_models, amount, environ, ) diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index b2e7ee1e3..a71859d0d 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -1,5 +1,4 @@ from backend.node.create_nodes.providers import ( - create_random_providers, get_default_providers, ) From 26c1ca1dd02400f3d6962fd562b66782b911ce0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 12:35:31 -0300 Subject: [PATCH 30/75] configure endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/llm_providers.py | 8 ++++- backend/protocol_rpc/endpoints.py | 41 ++++++++++++----------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index 9d2256a11..9e4ad4b39 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -33,7 +33,13 @@ def add(self, provider: LLMProvider) -> int: def edit(self, id: int, provider: LLMProvider): self.session.query(LLMProviderDBModel).filter( LLMProviderDBModel.id == id - ).update(_to_db_model(provider)) + ).update( + { + LLMProviderDBModel.provider: provider.provider, + LLMProviderDBModel.model: provider.model, + LLMProviderDBModel.config: provider.config, + } + ) self.session.commit() def delete(self, id: int): diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 9581626c2..af39aabba 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -178,24 +178,17 @@ def create_random_validator( config: GlobalConfiguration, stake: int, ) -> dict: - validator_address = accounts_manager.create_new_account().address - details = random_validator_config( - config.get_available_ollama_models, - llm_provider_registry.get_all, + return create_random_validators( + validators_registry, + accounts_manager, + llm_provider_registry, + config, 1, - )[0] - response = validators_registry.create_validator( - validator_address, stake, - details.provider, - details.model, - details.config, - ) - return response + stake, + )[0] -# TODO: Refactor this function to put the random config generator inside the domain -# and reuse the generate single random validator function def create_random_validators( validators_registry: ValidatorsRegistry, accounts_manager: AccountsManager, @@ -206,7 +199,7 @@ def create_random_validators( max_stake: int, limit_providers: list[str] = None, limit_models: list[str] = None, -) -> dict: +) -> dict: # TODO: should return list limit_providers = limit_providers or [] limit_models = limit_models or [] @@ -218,18 +211,20 @@ def create_random_validators( amount=count, ) + response = [] for detail in details: stake = random.randint(min_stake, max_stake) validator_address = accounts_manager.create_new_account().address - validators_registry.create_validator( + validator = validators_registry.create_validator( validator_address, stake, detail.provider, detail.model, detail.config, ) - response = validators_registry.get_all_validators() + response.append(validator) + return response @@ -452,10 +447,18 @@ def register_all_rpc_endpoints( create_validator, validators_registry, accounts_manager ) register_rpc_endpoint_for_partial( - create_random_validator, validators_registry, accounts_manager, config + create_random_validator, + validators_registry, + accounts_manager, + llm_provider_registry, + config, ) register_rpc_endpoint_for_partial( - create_random_validators, validators_registry, accounts_manager, config + create_random_validators, + validators_registry, + accounts_manager, + llm_provider_registry, + config, ) register_rpc_endpoint_for_partial( update_validator, validators_registry, accounts_manager From bdccdf8ae1a2d761e88fcc22104c390bf5410a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 14:33:30 -0300 Subject: [PATCH 31/75] update e2e tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- tests/integration/conftest.py | 19 +++++++++++++++++++ .../test_football_prediction_market.py | 13 +------------ .../contract_examples/test_llm_erc20.py | 13 +------------ .../contract_examples/test_log_indexer.py | 8 +------- .../contract_examples/test_storage.py | 13 +------------ .../contract_examples/test_user_storage.py | 13 +------------ .../contract_examples/test_wizard_of_coin.py | 12 +----------- 7 files changed, 25 insertions(+), 66 deletions(-) create mode 100644 tests/integration/conftest.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 000000000..53cdef0cb --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,19 @@ +import pytest + +from tests.common.request import payload, post_request_localhost +from tests.common.response import has_success_status + + +@pytest.fixture +def setup_validators(): + result = post_request_localhost( + payload("create_random_validators", 5, 8, 12, ["openai"], ["gpt-4o"]) + ).json() + assert has_success_status(result) + + yield + + delete_validators_result = post_request_localhost( + payload("delete_all_validators") + ).json() + assert has_success_status(delete_validators_result) diff --git a/tests/integration/contract_examples/test_football_prediction_market.py b/tests/integration/contract_examples/test_football_prediction_market.py index b8045fb69..ce358065d 100644 --- a/tests/integration/contract_examples/test_football_prediction_market.py +++ b/tests/integration/contract_examples/test_football_prediction_market.py @@ -22,13 +22,7 @@ from tests.common.accounts import create_new_account -def test_football_prediction_market(): - # Validators Setup - result = post_request_localhost( - payload("create_random_validators", 5, 8, 12, ["openai"], None, "gpt-4o") - ).json() - assert has_success_status(result) - +def test_football_prediction_market(setup_validators): # Account Setup from_account = create_new_account() @@ -64,8 +58,3 @@ def test_football_prediction_market(): # Assert response format assert_dict_struct(transaction_response_call_1, call_contract_function_response) - - delete_validators_result = post_request_localhost( - payload("delete_all_validators") - ).json() - assert has_success_status(delete_validators_result) diff --git a/tests/integration/contract_examples/test_llm_erc20.py b/tests/integration/contract_examples/test_llm_erc20.py index bae19e2ee..d65811728 100644 --- a/tests/integration/contract_examples/test_llm_erc20.py +++ b/tests/integration/contract_examples/test_llm_erc20.py @@ -25,13 +25,7 @@ TRANSFER_AMOUNT = 100 -def test_llm_erc20(): - # Validators Setup - result = post_request_localhost( - payload("create_random_validators", 5, 8, 12, ["openai"], None, "gpt-4o") - ).json() - assert has_success_status(result) - +def test_llm_erc20(setup_validators): # Account Setup from_account_a = create_new_account() from_account_b = create_new_account() @@ -106,8 +100,3 @@ def test_llm_erc20(): ) assert has_success_status(contract_state_2_3) assert contract_state_2_3["result"]["data"] == TRANSFER_AMOUNT - - delete_validators_result = post_request_localhost( - payload("delete_all_validators") - ).json() - assert has_success_status(delete_validators_result) diff --git a/tests/integration/contract_examples/test_log_indexer.py b/tests/integration/contract_examples/test_log_indexer.py index b56c75e88..4c1ff863b 100644 --- a/tests/integration/contract_examples/test_log_indexer.py +++ b/tests/integration/contract_examples/test_log_indexer.py @@ -25,13 +25,7 @@ TRANSFER_AMOUNT = 100 -def test_log_indexer(): - # Validators Setup - result = post_request_localhost( - payload("create_random_validators", 5, 8, 12, ["openai"], None, "gpt-4o-mini") - ).json() - assert has_success_status(result) - +def test_log_indexer(setup_validators): # Account Setup from_account = create_new_account() diff --git a/tests/integration/contract_examples/test_storage.py b/tests/integration/contract_examples/test_storage.py index dd1f5786f..e078fe3e5 100644 --- a/tests/integration/contract_examples/test_storage.py +++ b/tests/integration/contract_examples/test_storage.py @@ -25,13 +25,7 @@ UPDATED_STATE = "b" -def test_storage(): - # Validators Setup - result = post_request_localhost( - payload("create_random_validators", 10, 8, 12, ["openai"], None, "gpt-4o-mini") - ).json() - assert has_success_status(result) - +def test_storage(setup_validators): # Account Setup from_account = create_new_account() @@ -76,8 +70,3 @@ def test_storage(): ) assert has_success_status(contract_state_2) assert contract_state_2["result"]["data"] == UPDATED_STATE - - delete_validators_result = post_request_localhost( - payload("delete_all_validators") - ).json() - assert has_success_status(delete_validators_result) diff --git a/tests/integration/contract_examples/test_user_storage.py b/tests/integration/contract_examples/test_user_storage.py index ca0ee4da2..9b959a6c4 100644 --- a/tests/integration/contract_examples/test_user_storage.py +++ b/tests/integration/contract_examples/test_user_storage.py @@ -28,13 +28,7 @@ UPDATED_STATE_USER_B = "user_b_updated_state" -def test_user_storage(): - # Validators Setup - result = post_request_localhost( - payload("create_random_validators", 10, 8, 12, ["openai"], None, "gpt-4o-mini") - ).json() - assert has_success_status(result) - +def test_user_storage(setup_validators): # Account Setup from_account_a = create_new_account() from_account_b = create_new_account() @@ -155,8 +149,3 @@ def test_user_storage(): ) assert has_success_status(contract_state_4_2) assert contract_state_4_2["result"]["data"] == INITIAL_STATE_USER_B - - delete_validators_result = post_request_localhost( - payload("delete_all_validators") - ).json() - assert has_success_status(delete_validators_result) diff --git a/tests/integration/contract_examples/test_wizard_of_coin.py b/tests/integration/contract_examples/test_wizard_of_coin.py index 18bf75caa..a2a06b137 100644 --- a/tests/integration/contract_examples/test_wizard_of_coin.py +++ b/tests/integration/contract_examples/test_wizard_of_coin.py @@ -22,13 +22,8 @@ from tests.common.accounts import create_new_account -def test_wizard_of_coin(): +def test_wizard_of_coin(setup_validators): print("test_wizard_of_coin") - # Validators - result = post_request_localhost( - payload("create_random_validators", 10, 8, 12, ["openai"], None, "gpt-4o-mini") - ).json() - assert has_success_status(result) # Account Setup from_account = create_new_account() @@ -61,8 +56,3 @@ def test_wizard_of_coin(): # Assert format assert_dict_struct(transaction_response_call_1, call_contract_function_response) - - delete_validators_result = post_request_localhost( - payload("delete_all_validators") - ).json() - assert has_success_status(delete_validators_result) From e20311d8953d5579e364540ec2d6eba428d421ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 14:47:32 -0300 Subject: [PATCH 32/75] add validators tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- tests/integration/test_validators.py | 44 ++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 tests/integration/test_validators.py diff --git a/tests/integration/test_validators.py b/tests/integration/test_validators.py new file mode 100644 index 000000000..eda4466bb --- /dev/null +++ b/tests/integration/test_validators.py @@ -0,0 +1,44 @@ +from tests.common.request import payload, post_request_localhost +from tests.common.response import has_success_status + + +def test_validators(): + delete_validators_result = post_request_localhost( + payload("delete_all_validators") + ).json() + assert has_success_status(delete_validators_result) + + response = post_request_localhost(payload("create_random_validator", 4)).json() + assert has_success_status(response) + + validator = response["result"]["data"] + first_address = validator["address"] + + # Duplicate validator + response = post_request_localhost( + payload( + "create_validator", + validator["stake"], + validator["provider"], + validator["model"], + validator["config"], + ) + ).json() + assert has_success_status(response) + + second_address = response["result"]["data"]["address"] + + # Delete both validators + response = post_request_localhost(payload("delete_validator", first_address)).json() + assert has_success_status(response) + + response = post_request_localhost( + payload("delete_validator", second_address) + ).json() + assert has_success_status(response) + + # Check no validators are left + + response = post_request_localhost(payload("get_all_validators")).json() + assert has_success_status(response) + assert response["result"]["data"] == [] From 6339de26a5e5f372a238af74ff716a9ba391fa8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 19:14:56 -0300 Subject: [PATCH 33/75] add llm provider endpoint test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/llm_providers.py | 14 +++++++- backend/node/create_nodes/providers.py | 5 ++- backend/protocol_rpc/endpoints.py | 24 ++++++++++--- .../test_llm_providers_registry.py | 34 +++++++++++++++++++ 4 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 tests/integration/test_llm_providers_registry.py diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index 9e4ad4b39..ec01afbbe 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -8,6 +8,7 @@ class LLMProviderRegistry: def __init__(self, session: Session): self.session = session + # TODO: we should call this to fill up the database with the default providers def reset_defaults(self): """Reset all providers to their default values.""" self.session.query(LLMProviderDBModel).delete() @@ -24,13 +25,24 @@ def get_all(self) -> list[LLMProvider]: for provider in self.session.query(LLMProviderDBModel).all() ] + def get_all_dict(self) -> list[dict]: + return [ + { + "id": provider.id, + "provider": provider.provider, + "model": provider.model, + "config": provider.config, + } + for provider in self.session.query(LLMProviderDBModel).all() + ] + def add(self, provider: LLMProvider) -> int: model = _to_db_model(provider) self.session.add(model) self.session.commit() return model.id - def edit(self, id: int, provider: LLMProvider): + def update(self, id: int, provider: LLMProvider): self.session.query(LLMProviderDBModel).filter( LLMProviderDBModel.id == id ).update( diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 302bc1955..28f8106c4 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -21,7 +21,10 @@ def get_schema() -> dict: def validate_provider(provider: LLMProvider): schema = get_schema() - validate(instance=provider.__dict__, schema=schema) + try: + validate(instance=provider.__dict__, schema=schema) + except Exception as e: + raise ValueError(f"Error validating provider: {e}") def get_default_providers() -> List[LLMProvider]: diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index af39aabba..c16505bca 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -126,7 +126,7 @@ def reset_defaults_llm_providers(llm_provider_registry: LLMProviderRegistry) -> def get_providers_and_models(llm_provider_registry: LLMProviderRegistry) -> dict: - return llm_provider_registry.get_all() + return llm_provider_registry.get_all_dict() def add_provider(llm_provider_registry: LLMProviderRegistry, params: dict) -> dict: @@ -140,7 +140,7 @@ def add_provider(llm_provider_registry: LLMProviderRegistry, params: dict) -> di return llm_provider_registry.add(provider) -def edit_provider( +def update_provider( llm_provider_registry: LLMProviderRegistry, id: int, params: dict ) -> dict: provider = LLMProvider( @@ -150,7 +150,7 @@ def edit_provider( ) validate_provider(provider) - llm_provider_registry.edit(id, provider) + llm_provider_registry.update(id, provider) def delete_provider(llm_provider_registry: LLMProviderRegistry, id: int) -> dict: @@ -165,6 +165,15 @@ def create_validator( model: str, config: json, ) -> dict: + + validate_provider( + LLMProvider( + provider=provider, + model=model, + config=config, + ) + ) + new_address = accounts_manager.create_new_account().address return validators_registry.create_validator( new_address, stake, provider, model, config @@ -240,6 +249,13 @@ def update_validator( # Remove validation while adding migration to update the db address # if not accounts_manager.is_valid_address(validator_address): # raise InvalidAddressError(validator_address) + validate_provider( + LLMProvider( + provider=provider, + model=model, + config=config, + ) + ) return validators_registry.update_validator( validator_address, stake, provider, model, config ) @@ -441,7 +457,7 @@ def register_all_rpc_endpoints( reset_defaults_llm_providers, llm_provider_registry ) register_rpc_endpoint_for_partial(add_provider, llm_provider_registry) - register_rpc_endpoint_for_partial(edit_provider, llm_provider_registry) + register_rpc_endpoint_for_partial(update_provider, llm_provider_registry) register_rpc_endpoint_for_partial(delete_provider, llm_provider_registry) register_rpc_endpoint_for_partial( create_validator, validators_registry, accounts_manager diff --git a/tests/integration/test_llm_providers_registry.py b/tests/integration/test_llm_providers_registry.py new file mode 100644 index 000000000..ec89d64e8 --- /dev/null +++ b/tests/integration/test_llm_providers_registry.py @@ -0,0 +1,34 @@ +from tests.common.request import payload, post_request_localhost +from tests.common.response import has_success_status + + +def test_llm_providers(): + reset_result = post_request_localhost( + payload("reset_defaults_llm_providers") + ).json() + assert has_success_status(reset_result) + + response = post_request_localhost(payload("get_providers_and_models")).json() + assert has_success_status(response) + + default_providers = response["result"]["data"] + first_default_provider = default_providers[0] + last_provider_id = default_providers[-1]["id"] + + # Create a new provider + response = post_request_localhost( + payload("add_provider", first_default_provider) + ).json() + assert has_success_status(response) + + provider_id = response["result"]["data"] + + # + response = post_request_localhost( + payload("update_provider", last_provider_id, first_default_provider) + ).json() + assert has_success_status(response) + + # Delete it + response = post_request_localhost(payload("delete_provider", provider_id)).json() + assert has_success_status(response) From 0c0f4579ba53944fd9603f138fbc98f82b638187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 19:25:34 -0300 Subject: [PATCH 34/75] fix todo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/alembic.ini | 2 +- backend/database_handler/llm_providers.py | 1 - backend/database_handler/migration/env.py | 5 ++++- .../versions/db38e78684a8_add_providers_table.py | 9 +++++++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/backend/database_handler/alembic.ini b/backend/database_handler/alembic.ini index 5d70e020e..cecac4987 100644 --- a/backend/database_handler/alembic.ini +++ b/backend/database_handler/alembic.ini @@ -13,7 +13,7 @@ script_location = migration # sys.path path, will be prepended to sys.path if present. # defaults to the current working directory. -prepend_sys_path = . +prepend_sys_path = ./backend/database_handler # timezone to use when rendering the date within the migration file # as well as the filename. diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index ec01afbbe..2757b836c 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -8,7 +8,6 @@ class LLMProviderRegistry: def __init__(self, session: Session): self.session = session - # TODO: we should call this to fill up the database with the default providers def reset_defaults(self): """Reset all providers to their default values.""" self.session.query(LLMProviderDBModel).delete() diff --git a/backend/database_handler/migration/env.py b/backend/database_handler/migration/env.py index 9228ef1e8..78e066ea6 100644 --- a/backend/database_handler/migration/env.py +++ b/backend/database_handler/migration/env.py @@ -1,4 +1,5 @@ from logging.config import fileConfig +import sys from sqlalchemy import engine_from_config from sqlalchemy import pool @@ -6,6 +7,8 @@ from alembic import context import os +# set up Python path as the project root directory, so that we can import as backend... +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) DB_URL = os.environ.get("DB_URL") @@ -21,7 +24,7 @@ # add your model's MetaData object here # for 'autogenerate' support -from models import Base +from backend.database_handler.models import Base target_metadata = Base.metadata diff --git a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py index e70946954..747b1403b 100644 --- a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py +++ b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py @@ -10,8 +10,11 @@ from alembic import op import sqlalchemy as sa +from sqlalchemy.orm import sessionmaker from sqlalchemy.dialects import postgresql +from backend.database_handler.llm_providers import LLMProviderRegistry + # revision identifiers, used by Alembic. revision: str = "db38e78684a8" down_revision: Union[str, None] = "d9ddc7436122" @@ -42,6 +45,12 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id", name="llm_provider_pkey"), ) # ### end Alembic commands ### + # Get the connection from Alembic + bind = op.get_bind() + + # Create a new SQLAlchemy session using the connection + with sessionmaker(bind=bind)() as session: + LLMProviderRegistry(session).reset_defaults() def downgrade() -> None: From 4ea314547630e8114270f60d0d93dfdadc35b58a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 6 Sep 2024 19:31:03 -0300 Subject: [PATCH 35/75] remove venv in pre-commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 876269f34..f566f2e6e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: hooks: - id: backend-unit-pytest name: backend unit tests with pytest - entry: bash -c "source .venv/bin/activate && pytest tests/unit" + entry: pytest tests/unit language: system types: [python] pass_filenames: false From 5fd2727e64347eb977420c564607a6c1266cf2b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Mon, 9 Sep 2024 08:32:38 -0300 Subject: [PATCH 36/75] improve types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/genvm/equivalence_principle.py | 3 ++- backend/node/genvm/llms.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py index 5827496dc..36c3a954b 100644 --- a/backend/node/genvm/equivalence_principle.py +++ b/backend/node/genvm/equivalence_principle.py @@ -1,6 +1,7 @@ # backend/node/genvm/equivalence_principle.py from typing import Optional +from backend.node.genvm.base import ContractRunner from backend.node.genvm.context_wrapper import enforce_with_context from backend.node.genvm import llms from backend.node.genvm.webpage_utils import get_webpage_content @@ -22,7 +23,7 @@ def clear_locals(scope): @enforce_with_context class EquivalencePrinciple: - contract_runner: dict + contract_runner: ContractRunner def __init__( self, diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index 43c99faa6..cacd2e2ec 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -6,7 +6,8 @@ import aiohttp import asyncio from typing import Optional -from openai import OpenAI +from openai import OpenAI, Stream +from openai.types.chat import ChatCompletionChunk from dotenv import load_dotenv @@ -99,7 +100,7 @@ async def call_heuristai( return output -def get_openai_client(api_key: str, url: str = None): +def get_openai_client(api_key: str, url: str = None) -> OpenAI: openai_client = None if url: openai_client = OpenAI(api_key=api_key, base_url=url) @@ -108,7 +109,7 @@ def get_openai_client(api_key: str, url: str = None): return openai_client -def get_openai_stream(client, prompt, model_config): +def get_openai_stream(client: OpenAI, prompt, model_config): config = model_config["config"] if "temperature" in config and "max_tokens" in config: return client.chat.completions.create( @@ -126,7 +127,9 @@ def get_openai_stream(client, prompt, model_config): ) -async def get_openai_output(stream, regex, return_streaming_channel): +async def get_openai_output( + stream: Stream[ChatCompletionChunk], regex, return_streaming_channel +): buffer = "" for chunk in stream: chunk_str = chunk.choices[0].delta.content From 73ce60490e1e1a50ef94d2c864a9009210fe4059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Mon, 9 Sep 2024 09:19:43 -0300 Subject: [PATCH 37/75] remove unnecessary import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/protocol_rpc/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index 254f4f1a3..27f8db634 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -21,4 +21,3 @@ sentence-transformers==3.0.1 Flask-SQLAlchemy==3.1.1 jsf==0.11.2 jsonschema==4.23.0 -hypothesis_jsonschema==0.23.1 From a8c2cfdd105d1fa8c6766e22baa3f577a9e34b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Mon, 9 Sep 2024 14:18:23 -0300 Subject: [PATCH 38/75] add plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../heuristai_dolphin-2.9-llama3-8b.json | 1 + .../heuristai_meta-llamallama-2-70b-chat.json | 1 + ...ristai_mistralaimixtral-8x7b-instruct.json | 1 + .../heuristai_openhermes-2-yi-34b-gptq.json | 1 + .../default_providers/ollama_gemma.json | 1 + .../default_providers/ollama_llama3.json | 1 + .../default_providers/ollama_mistral.json | 1 + .../default_providers/openai_gpt-4.json | 1 + .../default_providers/openai_gpt-4o-mini.json | 1 + .../default_providers/openai_gpt-4o.json | 1 + .../node/create_nodes/providers_schema.json | 15 ++++++++++- backend/node/genvm/equivalence_principle.py | 4 +-- backend/node/genvm/llms.py | 25 ++++++++++++++++++- 13 files changed, 49 insertions(+), 5 deletions(-) diff --git a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json index 57e6b536f..be085f7cb 100644 --- a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json +++ b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json @@ -1,5 +1,6 @@ { "provider": "heuristai", + "plugin": "heuristai", "model": "dolphin-2.9-llama3-8b", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json index a405b094e..f3cf4b328 100644 --- a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json +++ b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json @@ -1,5 +1,6 @@ { "provider": "heuristai", + "plugin": "heuristai", "model": "meta-llama/llama-2-70b-chat", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json index c65df3de8..11981b318 100644 --- a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json +++ b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json @@ -1,5 +1,6 @@ { "provider": "heuristai", + "plugin": "heuristai", "model": "mistralai/mixtral-8x7b-instruct", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json index 9bf75a184..f4b001b2f 100644 --- a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json +++ b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json @@ -1,5 +1,6 @@ { "provider": "heuristai", + "plugin": "heuristai", "model": "openhermes-2-yi-34b-gptq", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/ollama_gemma.json b/backend/node/create_nodes/default_providers/ollama_gemma.json index c6b5eb810..21c693ccd 100644 --- a/backend/node/create_nodes/default_providers/ollama_gemma.json +++ b/backend/node/create_nodes/default_providers/ollama_gemma.json @@ -1,5 +1,6 @@ { "provider": "ollama", + "plugin": "ollama", "model": "gemma", "config": { "mirostat": 0, diff --git a/backend/node/create_nodes/default_providers/ollama_llama3.json b/backend/node/create_nodes/default_providers/ollama_llama3.json index 1e59247a0..03bf2858b 100644 --- a/backend/node/create_nodes/default_providers/ollama_llama3.json +++ b/backend/node/create_nodes/default_providers/ollama_llama3.json @@ -1,5 +1,6 @@ { "provider": "ollama", + "plugin": "ollama", "model": "llama3", "config": { "mirostat": 0, diff --git a/backend/node/create_nodes/default_providers/ollama_mistral.json b/backend/node/create_nodes/default_providers/ollama_mistral.json index 7232c036d..36af67e5c 100644 --- a/backend/node/create_nodes/default_providers/ollama_mistral.json +++ b/backend/node/create_nodes/default_providers/ollama_mistral.json @@ -1,5 +1,6 @@ { "provider": "ollama", + "plugin": "ollama", "model": "mistral", "config": { "mirostat": 0, diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4.json b/backend/node/create_nodes/default_providers/openai_gpt-4.json index bb2ed93f1..ddb9a284b 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4.json @@ -1,5 +1,6 @@ { "provider": "openai", + "plugin": "openai", "model": "gpt-4", "config": "" } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json index a5a387cd7..302bb8645 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json @@ -1,5 +1,6 @@ { "provider": "openai", + "plugin": "openai", "model": "gpt-4o-mini", "config": "" } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o.json b/backend/node/create_nodes/default_providers/openai_gpt-4o.json index d0f89403d..9be18865e 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o.json @@ -1,5 +1,6 @@ { "provider": "openai", + "plugin": "openai", "model": "gpt-4o", "config": "" } diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 987980f79..bc7591f08 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -18,6 +18,10 @@ "const": "" } ] + }, + "plugin": { + "$comment": "plugin to be loaded by the simulator to interact with the provider", + "enum": ["heuristai", "openai", "ollama"] } }, "allOf": [ @@ -30,6 +34,9 @@ "then": { "properties": { + "plugin": { + "const": "ollama" + }, "model": { "enum": ["llama3", "mistral", "gemma"] }, @@ -163,6 +170,9 @@ }, "then": { "properties": { + "plugin": { + "const": "heuristai" + }, "model": { "enum": [ "mistralai/mixtral-8x7b-instruct", @@ -203,6 +213,9 @@ }, "then": { "properties": { + "plugin": { + "const": "openai" + }, "model": { "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] }, @@ -213,6 +226,6 @@ } } ], - "required": ["provider", "model", "config"], + "required": ["provider", "model", "config", "plugin"], "additionalProperties": false } diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py index 36c3a954b..0c5848971 100644 --- a/backend/node/genvm/equivalence_principle.py +++ b/backend/node/genvm/equivalence_principle.py @@ -93,9 +93,7 @@ def set(self, value): self.contract_runner.eq_num += 1 def __get_llm_function(self): - function_name = "call_" + self.contract_runner.node_config["provider"] - llm_function = getattr(llms, function_name) - return llm_function + return llms.get_llm_function(self.contract_runner.node_config["plugin"]) async def call_llm_with_principle(prompt, eq_principle, comparative=True): diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index cacd2e2ec..d6755ca8c 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -1,4 +1,11 @@ -# backend/node/genvm/llms.py +""" +This file contains the plugins (functions) that are used to interact with the different LLMs (Language Model Models) that are used in the system. The plugins are registered in the `get_llm_function` function, which returns the function that corresponds to the plugin name. The plugins are called with the following parameters: + +- `model_config`: A dictionary containing the model and configuration to be used. +- `prompt`: The prompt to be sent to the LLM. +- `regex`: A regular expression to be used to stop the LLM. +- `return_streaming_channel`: An optional asyncio.Queue to stream the response. +""" import os import re @@ -152,3 +159,19 @@ async def get_openai_output( def get_ollama_url(endpoint: str) -> str: return f"{os.environ['OLAMAPROTOCOL']}://{os.environ['OLAMAHOST']}:{os.environ['OLAMAPORT']}/api/{endpoint}" + + +def get_llm_function(plugin: str): + """ + Function to register new providers + """ + plugin_to_function = { + "ollama": call_ollama, + "openai": call_openai, + "heuristai": call_heuristai, + } + + if plugin not in plugin_to_function: + raise ValueError(f"Plugin {plugin} not registered.") + + return plugin_to_function[plugin] From 3aa84d9b0e51bf8bca4a3514ad5461d2168da6ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Mon, 9 Sep 2024 14:42:45 -0300 Subject: [PATCH 39/75] improve plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../heuristai_dolphin-2.9-llama3-8b.json | 4 +++ .../heuristai_meta-llamallama-2-70b-chat.json | 4 +++ ...ristai_mistralaimixtral-8x7b-instruct.json | 4 +++ .../heuristai_openhermes-2-yi-34b-gptq.json | 4 +++ .../default_providers/openai_gpt-4.json | 5 ++- .../default_providers/openai_gpt-4o-mini.json | 5 ++- .../default_providers/openai_gpt-4o.json | 5 ++- .../node/create_nodes/providers_schema.json | 35 +++++++++++++++++-- backend/node/genvm/llms.py | 11 +++--- 9 files changed, 68 insertions(+), 9 deletions(-) diff --git a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json index be085f7cb..191222f40 100644 --- a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json +++ b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json @@ -5,5 +5,9 @@ "config": { "temperature": 0.75, "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "HEURISTAIAPIKEY", + "url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json index f3cf4b328..e8e16a1b8 100644 --- a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json +++ b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json @@ -5,5 +5,9 @@ "config": { "temperature": 0.75, "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "HEURISTAIAPIKEY", + "url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json index 11981b318..eaeb6f35e 100644 --- a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json +++ b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json @@ -5,5 +5,9 @@ "config": { "temperature": 0.75, "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "HEURISTAIAPIKEY", + "url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json index f4b001b2f..86b5a5609 100644 --- a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json +++ b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json @@ -5,5 +5,9 @@ "config": { "temperature": 0.75, "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "HEURISTAIAPIKEY", + "url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4.json b/backend/node/create_nodes/default_providers/openai_gpt-4.json index ddb9a284b..2872fc8be 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4.json @@ -2,5 +2,8 @@ "provider": "openai", "plugin": "openai", "model": "gpt-4", - "config": "" + "config": "", + "plugin_config": { + "api_key_env_var": "OPENAIKEY" + } } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json index 302bb8645..e6a9d71be 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json @@ -2,5 +2,8 @@ "provider": "openai", "plugin": "openai", "model": "gpt-4o-mini", - "config": "" + "config": "", + "plugin_config": { + "api_key_env_var": "OPENAIKEY" + } } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o.json b/backend/node/create_nodes/default_providers/openai_gpt-4o.json index 9be18865e..69b187198 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o.json @@ -2,5 +2,8 @@ "provider": "openai", "plugin": "openai", "model": "gpt-4o", - "config": "" + "config": "", + "plugin_config": { + "api_key_env_var": "OPENAIKEY" + } } diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index bc7591f08..6bec67741 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -22,6 +22,9 @@ "plugin": { "$comment": "plugin to be loaded by the simulator to interact with the provider", "enum": ["heuristai", "openai", "ollama"] + }, + "plugin_config": { + "type": "object" } }, "allOf": [ @@ -173,6 +176,21 @@ "plugin": { "const": "heuristai" }, + "plugin_config": { + "type": "object", + "additionalProperties": false, + "properties": { + "api_key_env_var": { + "type": "string", + "$comment": "Environment variable that contains the API key" + }, + "url": { + "type": "string", + "$comment": "URL of the API endpoint" + } + }, + "required": ["api_key_env_var", "url"] + }, "model": { "enum": [ "mistralai/mixtral-8x7b-instruct", @@ -202,7 +220,8 @@ }, "required": ["temperature", "max_tokens"] } - } + }, + "required": ["plugin_config"] } }, { @@ -216,13 +235,25 @@ "plugin": { "const": "openai" }, + "plugin_config": { + "type": "object", + "additionalProperties": false, + "properties": { + "api_key_env_var": { + "type": "string", + "$comment": "Environment variable that contains the API key" + } + }, + "required": ["api_key_env_var"] + }, "model": { "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] }, "config": { "const": "" } - } + }, + "required": ["plugin_config"] } } ], diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index d6755ca8c..7386aaf82 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -20,6 +20,8 @@ load_dotenv() +plugin_config_key = "plugin_config" + async def process_streaming_buffer(buffer: str, chunk: str, regex: str) -> str: updated_buffer = buffer + chunk @@ -77,7 +79,8 @@ async def call_openai( regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - client = get_openai_client(os.environ.get("OPENAIKEY")) + api_key_env_var = model_config[plugin_config_key]["api_key_env_var"] + client = get_openai_client(os.environ.get(api_key_env_var)) # TODO: OpenAI exceptions need to be caught here stream = get_openai_stream(client, prompt, model_config) @@ -90,9 +93,9 @@ async def call_heuristai( regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - client = get_openai_client( - os.environ.get("HEURISTAIAPIKEY"), os.environ.get("HEURISTAIURL") - ) + api_key_env_var = model_config[plugin_config_key]["api_key_env_var"] + url = model_config[plugin_config_key]["url"] + client = get_openai_client(os.environ.get(api_key_env_var), os.environ.get(url)) stream = get_openai_stream(client, prompt, model_config) # TODO: Get the line below working # return await get_openai_output(stream, regex, return_streaming_channel) From 7e41f48ca4f42c02c06dd85bfd4eb87368c04392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 10 Sep 2024 08:07:46 -0300 Subject: [PATCH 40/75] docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 28f8106c4..50d28339e 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -71,8 +71,10 @@ def create_random_providers(amount: int) -> list[LLMProvider]: return_value = [] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", HypothesisDeprecationWarning) + with warnings.catch_warnings(): # Catch warnings from hypothesis telling us to not use it for this purpose + warnings.simplefilter( + "ignore", HypothesisDeprecationWarning + ) # Disable the warning about using the deprecated `suppress_health_check` argument @settings( max_examples=amount, suppress_health_check=(HealthCheck.return_value,) @@ -83,10 +85,10 @@ def create_random_providers(amount: int) -> list[LLMProvider]: ), ) def inner(value): - nonlocal return_value + nonlocal return_value # Hypothesis `@given` wrapper doesn't allow us to return from the "test" function, so I'm using this closure to return the value provider = _to_domain(value) validate_provider(provider) return_value.append(provider) - inner() + inner() # Calling the function will fill the return_value list return return_value From 7c0c898e84d760b126ad1fdd65a8b2b93767c193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 10 Sep 2024 10:13:37 -0300 Subject: [PATCH 41/75] add plugin config + fix docker config + minor extras MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/llm_providers.py | 11 +- .../db38e78684a8_add_providers_table.py | 5 +- backend/database_handler/models.py | 1 + backend/domain/types.py | 12 +- .../default_providers/ollama_gemma.json | 3 + .../default_providers/ollama_llama3.json | 3 + .../default_providers/ollama_mistral.json | 3 + backend/node/create_nodes/providers.py | 15 ++ .../node/create_nodes/providers_schema.json | 23 ++- backend/protocol_rpc/endpoints.py | 25 ++-- docker-compose.yml | 16 +- docker/Dockerfile.backend | 1 - docker/Dockerfile.database-migration | 10 +- .../test_llm_providers_registry.py | 43 +++++- tests/unit/test_create_nodes.py | 138 ++++++++++++++---- 15 files changed, 243 insertions(+), 66 deletions(-) diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index 2757b836c..ba42c0390 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -26,12 +26,7 @@ def get_all(self) -> list[LLMProvider]: def get_all_dict(self) -> list[dict]: return [ - { - "id": provider.id, - "provider": provider.provider, - "model": provider.model, - "config": provider.config, - } + _to_domain(provider).__dict__ for provider in self.session.query(LLMProviderDBModel).all() ] @@ -49,6 +44,7 @@ def update(self, id: int, provider: LLMProvider): LLMProviderDBModel.provider: provider.provider, LLMProviderDBModel.model: provider.model, LLMProviderDBModel.config: provider.config, + LLMProviderDBModel.plugin_config: provider.plugin_config, } ) self.session.commit() @@ -62,9 +58,11 @@ def delete(self, id: int): def _to_domain(db_model: LLMProvider) -> LLMProvider: return LLMProvider( + id=db_model.id, provider=db_model.provider, model=db_model.model, config=db_model.config, + plugin_config=db_model.plugin_config, ) @@ -73,4 +71,5 @@ def _to_db_model(domain: LLMProvider) -> LLMProviderDBModel: provider=domain.provider, model=domain.model, config=domain.config, + plugin_config=domain.plugin_config, ) diff --git a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py index 747b1403b..6091715de 100644 --- a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py +++ b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py @@ -17,7 +17,7 @@ # revision identifiers, used by Alembic. revision: str = "db38e78684a8" -down_revision: Union[str, None] = "d9ddc7436122" +down_revision: Union[str, None] = "f9636f013003" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -30,6 +30,9 @@ def upgrade() -> None: sa.Column("provider", sa.String(length=255), nullable=False), sa.Column("model", sa.String(length=255), nullable=False), sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column( + "plugin_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), sa.Column( "created_at", sa.DateTime(timezone=True), diff --git a/backend/database_handler/models.py b/backend/database_handler/models.py index 67c1452f6..b52096bfa 100644 --- a/backend/database_handler/models.py +++ b/backend/database_handler/models.py @@ -124,6 +124,7 @@ class LLMProviderDBModel(Base): provider: Mapped[str] = mapped_column(String(255)) model: Mapped[str] = mapped_column(String(255)) config: Mapped[dict | str] = mapped_column(JSONB) + plugin_config: Mapped[dict] = mapped_column(JSONB) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(True), server_default=func.current_timestamp(), init=False ) diff --git a/backend/domain/types.py b/backend/domain/types.py index 306092b0c..be1eb995e 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -10,9 +10,19 @@ class LLMProvider: provider: str model: str config: dict + plugin_config: dict + id: int | None = None def __hash__(self): - return hash((self.provider, self.model, frozenset(self.config.items()))) + return hash( + ( + self.id, + self.provider, + self.model, + frozenset(self.config.items()), + frozenset(self.plugin_config.items()), + ) + ) @dataclass() diff --git a/backend/node/create_nodes/default_providers/ollama_gemma.json b/backend/node/create_nodes/default_providers/ollama_gemma.json index 21c693ccd..06afeb047 100644 --- a/backend/node/create_nodes/default_providers/ollama_gemma.json +++ b/backend/node/create_nodes/default_providers/ollama_gemma.json @@ -19,5 +19,8 @@ "num_predict": 128, "top_k": 40, "top_p": 0.9 + }, + "plugin_config": { + "api_url": "http://ollama:11434/api/" } } diff --git a/backend/node/create_nodes/default_providers/ollama_llama3.json b/backend/node/create_nodes/default_providers/ollama_llama3.json index 03bf2858b..2adbc6dbe 100644 --- a/backend/node/create_nodes/default_providers/ollama_llama3.json +++ b/backend/node/create_nodes/default_providers/ollama_llama3.json @@ -19,5 +19,8 @@ "num_predict": 128, "top_k": 40, "top_p": 0.9 + }, + "plugin_config": { + "api_url": "http://ollama:11434/api/" } } diff --git a/backend/node/create_nodes/default_providers/ollama_mistral.json b/backend/node/create_nodes/default_providers/ollama_mistral.json index 36af67e5c..c43ecc02e 100644 --- a/backend/node/create_nodes/default_providers/ollama_mistral.json +++ b/backend/node/create_nodes/default_providers/ollama_mistral.json @@ -19,5 +19,8 @@ "num_predict": 128, "top_k": 40, "top_p": 0.9 + }, + "plugin_config": { + "api_url": "http://ollama:11434/api/" } } diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 50d28339e..2853249e7 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -50,11 +50,26 @@ def get_default_providers() -> List[LLMProvider]: return providers +def get_default_provider_for(provider: str, model: str) -> LLMProvider: + llm_providers = get_default_providers() + matches = [ + llm_provider + for llm_provider in llm_providers + if llm_provider.provider == provider and llm_provider.model == model + ] + if not matches: + raise ValueError(f"No default provider found for {provider} and {model}") + if len(matches) > 1: + raise ValueError(f"Multiple default providers found for {provider} and {model}") + + def _to_domain(provider: dict) -> LLMProvider: return LLMProvider( + id=None, provider=provider["provider"], model=provider["model"], config=provider["config"], + plugin_config=provider["plugin_config"], ) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 6bec67741..0411bd37d 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -4,7 +4,8 @@ "type": "object", "properties": { "provider": { - "enum": ["heuristai", "openai", "ollama"] + "type": "string", + "examples": ["ollama", "heuristai", "openai"] }, "model": { "type": "string" @@ -34,12 +35,22 @@ "provider": { "const": "ollama" } } }, - "then": { "properties": { "plugin": { "const": "ollama" }, + "plugin_config": { + "type": "object", + "additionalProperties": false, + "properties": { + "api_url": { + "type": "string", + "$comment": "Environment variable that contains the API key", + "default": "http://ollama:11434/api/" + } + } + }, "model": { "enum": ["llama3", "mistral", "gemma"] }, @@ -220,8 +231,7 @@ }, "required": ["temperature", "max_tokens"] } - }, - "required": ["plugin_config"] + } } }, { @@ -252,11 +262,10 @@ "config": { "const": "" } - }, - "required": ["plugin_config"] + } } } ], - "required": ["provider", "model", "config", "plugin"], + "required": ["provider", "model", "config", "plugin", "plugin_config"], "additionalProperties": false } diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index c16505bca..d123636fd 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -9,7 +9,11 @@ from backend.database_handler.llm_providers import LLMProviderRegistry from backend.database_handler.models import Base from backend.domain.types import LLMProvider -from backend.node.create_nodes.providers import validate_provider +from backend.node.create_nodes.providers import ( + get_default_provider_for, + get_default_providers, + validate_provider, +) from backend.protocol_rpc.configuration import GlobalConfiguration from backend.protocol_rpc.message_handler.base import MessageHandler from backend.database_handler.accounts_manager import AccountsManager @@ -163,16 +167,19 @@ def create_validator( stake: int, provider: str, model: str, - config: json, + config: dict | None, ) -> dict: - - validate_provider( - LLMProvider( - provider=provider, - model=model, - config=config, + # fallback for default provider + if not config: + config = get_default_provider_for(provider, model).config + else: + validate_provider( + LLMProvider( + provider=provider, + model=model, + config=config, + ) ) - ) new_address = accounts_manager.create_new_account().address return validators_registry.create_validator( diff --git a/docker-compose.yml b/docker-compose.yml index aba6171fd..6b9aea618 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -79,12 +79,24 @@ services: #volumes: # - "./data/postgres:/var/lib/postgresql/data" + build-backend-base: + build: + context: . + dockerfile: docker/Dockerfile.backend + target: base + image: backend-base + pull_policy: build + database-migration: build: - context: ./backend/database_handler - dockerfile: ../../docker/Dockerfile.database-migration + context: . + dockerfile: docker/Dockerfile.database-migration + args: + BASE_IMAGE: backend-base environment: - DB_URL=postgresql://${DBUSER}:${DBUSER}@postgres/${DBNAME} depends_on: + build-backend-base: + condition: service_completed_successfully postgres: condition: service_healthy diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index 91b446a82..c2b6d2c60 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -14,7 +14,6 @@ RUN groupadd -r backend-group \ && chown -R backend-user:backend-group /home/backend-user \ && chown -R backend-user:backend-group $path -ENV PYTHONPATH "${PYTHONPATH}:/${path}" ENV FLASK_APP=backend/protocol_rpc/server.py ENV TRANSFORMERS_CACHE=/home/backend-user/.cache/huggingface diff --git a/docker/Dockerfile.database-migration b/docker/Dockerfile.database-migration index 828f9cc30..461f5ba63 100644 --- a/docker/Dockerfile.database-migration +++ b/docker/Dockerfile.database-migration @@ -1,11 +1,9 @@ -FROM python:3.12.5-slim +ARG BASE_IMAGE +FROM ${BASE_IMAGE} -WORKDIR /app +WORKDIR /app/backend/database_handler -COPY migration/requirements.txt . -RUN pip install --no-cache-dir -r requirements.txt - -COPY . . +RUN pip install -r migration/requirements.txt ENTRYPOINT [ "alembic" ] CMD [ "upgrade", "head" ] diff --git a/tests/integration/test_llm_providers_registry.py b/tests/integration/test_llm_providers_registry.py index ec89d64e8..45442fc13 100644 --- a/tests/integration/test_llm_providers_registry.py +++ b/tests/integration/test_llm_providers_registry.py @@ -3,6 +3,47 @@ def test_llm_providers(): + provider = { + "provider": "openai", + "model": "gpt-3.5-turbo", + "config": "", + "plugin_config": {"api_key_env_var": "OPENAIKEY"}, + } + # Create a new provider + response = post_request_localhost(payload("add_provider", provider)).json() + assert has_success_status(response) + + provider_id = response["result"]["data"] + + updated_provider = { + "provider": "openai", + "model": "gpt-4o", + "config": "", + "plugin_config": {"api_key_env_var": "OPENAIKEY"}, + } + # Uodate it + response = post_request_localhost( + payload("update_provider", provider_id, updated_provider) + ).json() + assert has_success_status(response) + + # Delete it + response = post_request_localhost(payload("delete_provider", provider_id)).json() + assert has_success_status(response) + + +def test_llm_providers_behavior(): + """ + Test the behavior of LLM providers endpoints by performing the following steps: + + 1. Reset the default LLM providers. + 2. Retrieve the list of providers and models. + 3. Extract the first default provider and the ID of the last provider. + 4. Add a new provider using the first default provider's data. + 5. Update the last provider using the first default provider's data. + 6. Delete the newly added provider. + + """ reset_result = post_request_localhost( payload("reset_defaults_llm_providers") ).json() @@ -23,7 +64,7 @@ def test_llm_providers(): provider_id = response["result"]["data"] - # + # Uodate it response = post_request_localhost( payload("update_provider", last_provider_id, first_default_provider) ).json() diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index 622028475..eec0c33d5 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -8,85 +8,151 @@ [ pytest.param( ["llama3"], - [LLMProvider(provider="ollama", model="llama3", config={})], + [ + LLMProvider( + provider="ollama", model="llama3", config={}, plugin_config={} + ) + ], None, None, 10, {}, - [LLMProvider(provider="ollama", model="llama3", config={})], + [ + LLMProvider( + provider="ollama", model="llama3", config={}, plugin_config={} + ) + ], id="only ollama", ), pytest.param( ["llama3", "llama3.1"], [ - LLMProvider(provider="ollama", model="llama3.1", config={}), - LLMProvider(provider="openai", model="gpt-4", config={}), - LLMProvider(provider="openai", model="gpt-4o", config={}), - LLMProvider(provider="heuristai", model="", config={}), + LLMProvider( + provider="ollama", model="llama3.1", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4o", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="", config={}, plugin_config={} + ), ], None, None, 10, {"OPENAI_API_KEY": ""}, - [LLMProvider(provider="ollama", model="llama3.1", config={})], + [ + LLMProvider( + provider="ollama", model="llama3.1", config={}, plugin_config={} + ) + ], id="only ollama available", ), pytest.param( ["llama3", "llama3.1"], [ - LLMProvider(provider="ollama", model="llama3", config={}), - LLMProvider(provider="openai", model="gpt-4", config={}), - LLMProvider(provider="openai", model="gpt-4o", config={}), - LLMProvider(provider="heuristai", model="", config={}), - LLMProvider(provider="heuristai", model="a", config={}), - LLMProvider(provider="heuristai", model="b", config={}), + LLMProvider( + provider="ollama", model="llama3", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4o", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="a", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="b", config={}, plugin_config={} + ), ], ["openai"], None, 10, {"OPENAIKEY": "filled"}, [ - LLMProvider(provider="openai", model="gpt-4", config={}), - LLMProvider(provider="openai", model="gpt-4o", config={}), + LLMProvider( + provider="openai", model="gpt-4", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4o", config={}, plugin_config={} + ), ], id="only openai", ), pytest.param( ["llama3", "llama3.1"], [ - LLMProvider(provider="openai", model="gpt-4", config={}), - LLMProvider(provider="openai", model="gpt-4o", config={}), - LLMProvider(provider="heuristai", model="a", config={}), - LLMProvider(provider="heuristai", model="b", config={}), + LLMProvider( + provider="openai", model="gpt-4", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4o", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="a", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="b", config={}, plugin_config={} + ), ], ["heuristai"], ["a"], 10, {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, [ - LLMProvider(provider="heuristai", model="a", config={}), + LLMProvider( + provider="heuristai", model="a", config={}, plugin_config={} + ), ], id="only heuristai", ), pytest.param( ["llama3", "llama3.1"], [ - LLMProvider(provider="ollama", model="llama3.1", config={}), - LLMProvider(provider="openai", model="gpt-4", config={}), - LLMProvider(provider="openai", model="gpt-4o", config={}), - LLMProvider(provider="heuristai", model="a", config={}), - LLMProvider(provider="heuristai", model="b", config={}), + LLMProvider( + provider="ollama", model="llama3.1", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4o", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="a", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="b", config={}, plugin_config={} + ), ], None, None, 10, {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, [ - LLMProvider(provider="ollama", model="llama3.1", config={}), - LLMProvider(provider="openai", model="gpt-4", config={}), - LLMProvider(provider="openai", model="gpt-4o", config={}), - LLMProvider(provider="heuristai", model="a", config={}), - LLMProvider(provider="heuristai", model="b", config={}), + LLMProvider( + provider="ollama", model="llama3.1", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4", config={}, plugin_config={} + ), + LLMProvider( + provider="openai", model="gpt-4o", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="a", config={}, plugin_config={} + ), + LLMProvider( + provider="heuristai", model="b", config={}, plugin_config={} + ), ], id="all available", ), @@ -118,7 +184,11 @@ def test_random_validator_config( [ pytest.param( [], - [LLMProvider(provider="ollama", model="llama3", config={})], + [ + LLMProvider( + provider="ollama", model="llama3", config={}, plugin_config={} + ) + ], ["heuristai", "openai"], None, 10, @@ -128,7 +198,11 @@ def test_random_validator_config( ), pytest.param( [], - [LLMProvider(provider="ollama", model="llama3", config={})], + [ + LLMProvider( + provider="ollama", model="llama3", config={}, plugin_config={} + ) + ], ["ollama"], None, 10, From 2b03f1395c129f253720f4661e8465ebf70437d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 10 Sep 2024 10:25:49 -0300 Subject: [PATCH 42/75] fix dockerfile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- docker/Dockerfile.backend | 1 + docker/Dockerfile.database-migration | 1 + 2 files changed, 2 insertions(+) diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index c2b6d2c60..91b446a82 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -14,6 +14,7 @@ RUN groupadd -r backend-group \ && chown -R backend-user:backend-group /home/backend-user \ && chown -R backend-user:backend-group $path +ENV PYTHONPATH "${PYTHONPATH}:/${path}" ENV FLASK_APP=backend/protocol_rpc/server.py ENV TRANSFORMERS_CACHE=/home/backend-user/.cache/huggingface diff --git a/docker/Dockerfile.database-migration b/docker/Dockerfile.database-migration index 461f5ba63..8910019de 100644 --- a/docker/Dockerfile.database-migration +++ b/docker/Dockerfile.database-migration @@ -1,6 +1,7 @@ ARG BASE_IMAGE FROM ${BASE_IMAGE} +ENV PYTHONPATH "" WORKDIR /app/backend/database_handler RUN pip install -r migration/requirements.txt From f76ef900c8a361afc88dab72e0fcce50643aff69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 10 Sep 2024 10:30:37 -0300 Subject: [PATCH 43/75] hotfix cyclic error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/genvm/equivalence_principle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py index 0c5848971..fbd300d54 100644 --- a/backend/node/genvm/equivalence_principle.py +++ b/backend/node/genvm/equivalence_principle.py @@ -1,7 +1,6 @@ # backend/node/genvm/equivalence_principle.py from typing import Optional -from backend.node.genvm.base import ContractRunner from backend.node.genvm.context_wrapper import enforce_with_context from backend.node.genvm import llms from backend.node.genvm.webpage_utils import get_webpage_content @@ -23,7 +22,7 @@ def clear_locals(scope): @enforce_with_context class EquivalencePrinciple: - contract_runner: ContractRunner + contract_runner: any # TODO: this should be of type ContractRunner but that raises a cyclic import error def __init__( self, From f6d9a08076941410b3847e3d2ef06af6b20511ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 10 Sep 2024 15:12:21 -0300 Subject: [PATCH 44/75] start adding plugin configuration to validator, node and genvm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/llm_providers.py | 3 + ...6b0dda_add_plugin_and_plugin_config_to_.py | 52 ++++++++++++ .../db38e78684a8_add_providers_table.py | 1 + backend/database_handler/models.py | 7 +- .../database_handler/validators_registry.py | 83 ++++++++++--------- backend/domain/types.py | 6 +- backend/node/base.py | 18 +--- backend/node/create_nodes/providers.py | 1 + backend/node/genvm/base.py | 7 +- backend/node/genvm/equivalence_principle.py | 2 +- backend/node/genvm/llms.py | 62 ++++++++++++-- backend/protocol_rpc/endpoints.py | 13 ++- tests/integration/__init__.py | 0 .../test_llm_providers_registry.py | 3 +- tests/integration/test_validators.py | 2 + 15 files changed, 188 insertions(+), 72 deletions(-) create mode 100644 backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py create mode 100644 tests/integration/__init__.py diff --git a/backend/database_handler/llm_providers.py b/backend/database_handler/llm_providers.py index ba42c0390..c613001fa 100644 --- a/backend/database_handler/llm_providers.py +++ b/backend/database_handler/llm_providers.py @@ -44,6 +44,7 @@ def update(self, id: int, provider: LLMProvider): LLMProviderDBModel.provider: provider.provider, LLMProviderDBModel.model: provider.model, LLMProviderDBModel.config: provider.config, + LLMProviderDBModel.plugin: provider.plugin, LLMProviderDBModel.plugin_config: provider.plugin_config, } ) @@ -62,6 +63,7 @@ def _to_domain(db_model: LLMProvider) -> LLMProvider: provider=db_model.provider, model=db_model.model, config=db_model.config, + plugin=db_model.plugin, plugin_config=db_model.plugin_config, ) @@ -71,5 +73,6 @@ def _to_db_model(domain: LLMProvider) -> LLMProviderDBModel: provider=domain.provider, model=domain.model, config=domain.config, + plugin=domain.plugin, plugin_config=domain.plugin_config, ) diff --git a/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py b/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py new file mode 100644 index 000000000..c57462634 --- /dev/null +++ b/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py @@ -0,0 +1,52 @@ +"""add plugin and plugin_config to validators + +Revision ID: 986d9a6b0dda +Revises: db38e78684a8 +Create Date: 2024-09-10 14:47:10.730407 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "986d9a6b0dda" +down_revision: Union[str, None] = "db38e78684a8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "validators", sa.Column("plugin", sa.String(length=255), nullable=False) + ) + op.add_column( + "validators", + sa.Column( + "plugin_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + ) + op.alter_column( + "validators", "provider", existing_type=sa.VARCHAR(length=255), nullable=False + ) + op.alter_column( + "validators", "model", existing_type=sa.VARCHAR(length=255), nullable=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "validators", "model", existing_type=sa.VARCHAR(length=255), nullable=True + ) + op.alter_column( + "validators", "provider", existing_type=sa.VARCHAR(length=255), nullable=True + ) + op.drop_column("validators", "plugin_config") + op.drop_column("validators", "plugin") + # ### end Alembic commands ### diff --git a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py index 6091715de..efc9bf54f 100644 --- a/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py +++ b/backend/database_handler/migration/versions/db38e78684a8_add_providers_table.py @@ -30,6 +30,7 @@ def upgrade() -> None: sa.Column("provider", sa.String(length=255), nullable=False), sa.Column("model", sa.String(length=255), nullable=False), sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("plugin", sa.String(length=255), nullable=False), sa.Column( "plugin_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False ), diff --git a/backend/database_handler/models.py b/backend/database_handler/models.py index b52096bfa..dca6faa1c 100644 --- a/backend/database_handler/models.py +++ b/backend/database_handler/models.py @@ -109,8 +109,10 @@ class Validators(Base): stake: Mapped[int] = mapped_column(Integer) config: Mapped[dict] = mapped_column(JSONB) address: Mapped[Optional[str]] = mapped_column(String(255)) - provider: Mapped[Optional[str]] = mapped_column(String(255)) - model: Mapped[Optional[str]] = mapped_column(String(255)) + provider: Mapped[str] = mapped_column(String(255)) + model: Mapped[str] = mapped_column(String(255)) + plugin: Mapped[str] = mapped_column(String(255)) + plugin_config: Mapped[dict] = mapped_column(JSONB) created_at: Mapped[Optional[datetime.datetime]] = mapped_column( DateTime(True), server_default=func.current_timestamp(), init=False ) @@ -124,6 +126,7 @@ class LLMProviderDBModel(Base): provider: Mapped[str] = mapped_column(String(255)) model: Mapped[str] = mapped_column(String(255)) config: Mapped[dict | str] = mapped_column(JSONB) + plugin: Mapped[str] = mapped_column(String(255), nullable=False) plugin_config: Mapped[dict] = mapped_column(JSONB) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(True), server_default=func.current_timestamp(), init=False diff --git a/backend/database_handler/validators_registry.py b/backend/database_handler/validators_registry.py index 023418cee..0c9548551 100644 --- a/backend/database_handler/validators_registry.py +++ b/backend/database_handler/validators_registry.py @@ -2,6 +2,8 @@ from sqlalchemy.orm import Session +from backend.domain.types import LLMProvider, Validator + from .models import Validators from backend.errors.errors import ValidatorNotFound @@ -24,7 +26,7 @@ def __init__(self, session: Session): self.session = session self.db_validators_table = "validators" - def _get_validator_or_fail(self, validator_address: str): + def _get_validator_or_fail(self, validator_address: str) -> Validators: """Private method to check if an account exists, and raise an error if not.""" validator_data = ( @@ -35,9 +37,9 @@ def _get_validator_or_fail(self, validator_address: str): if validator_data is None: raise ValidatorNotFound(validator_address) - return to_dict(validator_data) + return validator_data - def count_validators(self): + def count_validators(self) -> int: return self.session.query(Validators).count() def get_all_validators(self) -> list: @@ -45,49 +47,24 @@ def get_all_validators(self) -> list: return [to_dict(validator) for validator in validators_data] def get_validator(self, validator_address: str) -> dict: - return self._get_validator_or_fail(validator_address) - - def create_validator( - self, - validator_address: str, - stake: int, - provider: str, - model: str, - config: dict, - ): - new_validator = Validators( - address=validator_address, - stake=stake, - provider=provider, - model=model, - config=config, - ) + return to_dict(self._get_validator_or_fail(validator_address)) - self.session.add(new_validator) + def create_validator(self, validator: Validator) -> dict: + self.session.add(_to_db_model(validator)) self.session.commit() - return self._get_validator_or_fail(validator_address) + return self.get_validator(validator.address) def update_validator( self, - validator_address: str, - stake: int, - provider: str, - model: str, - config: dict, - ): - self._get_validator_or_fail(validator_address) + new_validator: Validator, + ) -> dict: + validator = self._get_validator_or_fail(new_validator.address) - validator = ( - self.session.query(Validators) - .filter(Validators.address == validator_address) - .one() - ) - - validator.stake = stake - validator.provider = provider - validator.model = model - validator.config = config + validator.stake = new_validator.stake + validator.provider = new_validator.provider + validator.model = new_validator.model + validator.config = new_validator.config self.session.commit() @@ -104,3 +81,31 @@ def delete_validator(self, validator_address): def delete_all_validators(self): self.session.query(Validators).delete() self.session.commit() + + +def _to_domain(validator: Validators) -> Validator: + return Validator( + address=validator.address, + stake=validator.stake, + llmprovider=LLMProvider( + provider=validator.provider, + model=validator.model, + config=validator.config, + plugin=validator.plugin, + plugin_config=validator.plugin_config, + id=None, + ), + id=validator.id, + ) + + +def _to_db_model(validator: Validator) -> Validators: + return Validators( + address=validator.address, + stake=validator.stake, + provider=validator.llmprovider.provider, + model=validator.llmprovider.model, + config=validator.llmprovider.config, + plugin=validator.llmprovider.plugin, + plugin_config=validator.llmprovider.plugin_config, + ) diff --git a/backend/domain/types.py b/backend/domain/types.py index be1eb995e..c7ae0899e 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -10,6 +10,7 @@ class LLMProvider: provider: str model: str config: dict + plugin: str plugin_config: dict id: int | None = None @@ -29,6 +30,5 @@ def __hash__(self): class Validator: address: str stake: int - provider: str - model: str - config: dict + llmprovider: LLMProvider + id: int | None = None diff --git a/backend/node/base.py b/backend/node/base.py index a11713035..254d936e3 100644 --- a/backend/node/base.py +++ b/backend/node/base.py @@ -1,7 +1,7 @@ import json -import traceback from typing import Optional +from backend.domain.types import Validator from backend.node.genvm.base import GenVM from backend.database_handler.contract_snapshot import ContractSnapshot from backend.node.genvm.types import Receipt, ExecutionMode, Vote @@ -12,26 +12,16 @@ class Node: def __init__( self, contract_snapshot: ContractSnapshot, - address: str, validator_mode: ExecutionMode, - stake: int, - provider: str, - model: str, - config: dict, + validator: Validator, leader_receipt: Optional[Receipt] = None, msg_handler: MessageHandler = None, ): self.validator_mode = validator_mode - self.address = address - self.validator_info = { - "provider": provider, - "model": model, - "config": config, - "stake": stake, - } + self.address = validator.address self.leader_receipt = leader_receipt self.genvm = GenVM( - contract_snapshot, self.validator_mode, self.validator_info, msg_handler + contract_snapshot, self.validator_mode, validator.__dict__, msg_handler ) async def exec_transaction(self, transaction: dict): diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 2853249e7..0b9e42346 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -69,6 +69,7 @@ def _to_domain(provider: dict) -> LLMProvider: provider=provider["provider"], model=provider["model"], config=provider["config"], + plugin=provider["plugin"], plugin_config=provider["plugin_config"], ) diff --git a/backend/node/genvm/base.py b/backend/node/genvm/base.py index 4fadd22e4..72ac2d53a 100644 --- a/backend/node/genvm/base.py +++ b/backend/node/genvm/base.py @@ -12,6 +12,7 @@ from contextlib import contextmanager, redirect_stdout from backend.database_handler.contract_snapshot import ContractSnapshot +from backend.domain.types import Validator from backend.node.genvm.equivalence_principle import EquivalencePrinciple from backend.node.genvm.code_enforcement import code_enforcement_check from backend.node.genvm.std.vector_store import VectorStore @@ -53,13 +54,13 @@ def __init__( self, snapshot: ContractSnapshot, validator_mode: str, - validator_info: dict, + validator: dict, msg_handler: MessageHandler = None, ): self.snapshot = snapshot self.validator_mode = validator_mode self.msg_handler = msg_handler - self.contract_runner = ContractRunner(validator_mode, validator_info) + self.contract_runner = ContractRunner(validator_mode, validator) @staticmethod def _get_contract_class_name(contract_code: str) -> str: @@ -85,7 +86,7 @@ def _generate_receipt( gas_used=self.contract_runner.gas_used, mode=self.contract_runner.mode, contract_state=encoded_object, - node_config=self.contract_runner.node_config, + node_config=self.contract_runner.validator.__dict__, eq_outputs=self.contract_runner.eq_outputs, execution_result=execution_result, error=error, diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py index fbd300d54..f2286d740 100644 --- a/backend/node/genvm/equivalence_principle.py +++ b/backend/node/genvm/equivalence_principle.py @@ -92,7 +92,7 @@ def set(self, value): self.contract_runner.eq_num += 1 def __get_llm_function(self): - return llms.get_llm_function(self.contract_runner.node_config["plugin"]) + return llms.get_llm_function(self.contract_runner.node_config["plugin"]).call async def call_llm_with_principle(prompt, eq_principle, comparative=True): diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index 7386aaf82..0f8bfa65b 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -7,6 +7,7 @@ - `return_streaming_channel`: An optional asyncio.Queue to stream the response. """ +from abc import ABC, abstractmethod import os import re import json @@ -164,17 +165,64 @@ def get_ollama_url(endpoint: str) -> str: return f"{os.environ['OLAMAPROTOCOL']}://{os.environ['OLAMAHOST']}:{os.environ['OLAMAPORT']}/api/{endpoint}" -def get_llm_function(plugin: str): +class Plugin(ABC): + @abstractmethod + def call( + self, + model_config: dict, + prompt: str, + regex: Optional[str], + return_streaming_channel: Optional[asyncio.Queue], + ) -> str: + pass + + +class OllamaPlugin(Plugin): + async def call( + self, + model_config: dict, + prompt: str, + regex: Optional[str], + return_streaming_channel: Optional[asyncio.Queue], + ) -> str: + return await call_ollama(model_config, prompt, regex, return_streaming_channel) + + +class OpenAIPlugin(Plugin): + async def call( + self, + model_config: dict, + prompt: str, + regex: Optional[str], + return_streaming_channel: Optional[asyncio.Queue], + ) -> str: + return await call_openai(model_config, prompt, regex, return_streaming_channel) + + +class HeuristAIPlugin(Plugin): + async def call( + self, + model_config: dict, + prompt: str, + regex: Optional[str], + return_streaming_channel: Optional[asyncio.Queue], + ) -> str: + return await call_heuristai( + model_config, prompt, regex, return_streaming_channel + ) + + +def get_llm_function(plugin: str) -> Plugin: """ Function to register new providers """ - plugin_to_function = { - "ollama": call_ollama, - "openai": call_openai, - "heuristai": call_heuristai, + plugin_map = { + "ollama": OllamaPlugin(), + "openai": OpenAIPlugin(), + "heuristai": HeuristAIPlugin(), } - if plugin not in plugin_to_function: + if plugin not in plugin_map: raise ValueError(f"Plugin {plugin} not registered.") - return plugin_to_function[plugin] + return plugin_map[plugin] diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index d123636fd..88b14b0fb 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -138,6 +138,8 @@ def add_provider(llm_provider_registry: LLMProviderRegistry, params: dict) -> di provider=params["provider"], model=params["model"], config=params["config"], + plugin=params["plugin"], + plugin_config=["plugin_config"], ) validate_provider(provider) @@ -151,6 +153,8 @@ def update_provider( provider=params["provider"], model=params["model"], config=params["config"], + plugin=params["plugin"], + plugin_config=["plugin_config"], ) validate_provider(provider) @@ -168,16 +172,21 @@ def create_validator( provider: str, model: str, config: dict | None, + plugin: str | None, + plugin_config: dict | None, ) -> dict: # fallback for default provider - if not config: - config = get_default_provider_for(provider, model).config + if not (config and plugin and plugin_config): + provider = get_default_provider_for(provider, model) + config = provider.config else: validate_provider( LLMProvider( provider=provider, model=model, config=config, + plugin=plugin, + plugin_config=plugin_config, ) ) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/test_llm_providers_registry.py b/tests/integration/test_llm_providers_registry.py index 45442fc13..aac023374 100644 --- a/tests/integration/test_llm_providers_registry.py +++ b/tests/integration/test_llm_providers_registry.py @@ -53,7 +53,8 @@ def test_llm_providers_behavior(): assert has_success_status(response) default_providers = response["result"]["data"] - first_default_provider = default_providers[0] + first_default_provider: dict = default_providers[0] + del first_default_provider["id"] last_provider_id = default_providers[-1]["id"] # Create a new provider diff --git a/tests/integration/test_validators.py b/tests/integration/test_validators.py index eda4466bb..d3a004c5c 100644 --- a/tests/integration/test_validators.py +++ b/tests/integration/test_validators.py @@ -22,6 +22,8 @@ def test_validators(): validator["provider"], validator["model"], validator["config"], + validator["plugin"], + validator["plugin_config"], ) ).json() assert has_success_status(response) From fe391ad713b5bc2b4b39a7953e5d8718f2dbe508 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Tue, 10 Sep 2024 20:41:47 -0300 Subject: [PATCH 45/75] improve migration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- ...6b0dda_add_plugin_and_plugin_config_to_.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py b/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py index c57462634..6715e8bb0 100644 --- a/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py +++ b/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py @@ -9,9 +9,13 @@ from typing import Sequence, Union from alembic import op +from sqlalchemy.orm import sessionmaker import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from backend.database_handler.models import Validators +from backend.node.create_nodes.providers import get_default_provider_for + # revision identifiers, used by Alembic. revision: str = "986d9a6b0dda" down_revision: Union[str, None] = "db38e78684a8" @@ -20,15 +24,31 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### + op.add_column("validators", sa.Column("plugin", sa.String(length=255))) op.add_column( - "validators", sa.Column("plugin", sa.String(length=255), nullable=False) + "validators", + sa.Column("plugin_config", postgresql.JSONB(astext_type=sa.Text())), ) - op.add_column( + + bind = op.get_bind() + # Create a new SQLAlchemy session using the connection + with sessionmaker(bind=bind)() as session: + validators = session.query(Validators).all() + for validator in validators: + default_provider = get_default_provider_for( + provider=validator.provider, model=validator.model + ) + validator.plugin = default_provider.plugin + validator.plugin_config = default_provider.plugin_config + + op.alter_column( + "validators", "plugin", existing_type=sa.VARCHAR(length=255), nullable=False + ) + op.alter_column( "validators", - sa.Column( - "plugin_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False - ), + "plugin_config", + existing_type=postgresql.JSONB(astext_type=sa.Text()), + nullable=False, ) op.alter_column( "validators", "provider", existing_type=sa.VARCHAR(length=255), nullable=False @@ -36,11 +56,9 @@ def upgrade() -> None: op.alter_column( "validators", "model", existing_type=sa.VARCHAR(length=255), nullable=False ) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.alter_column( "validators", "model", existing_type=sa.VARCHAR(length=255), nullable=True ) @@ -49,4 +67,3 @@ def downgrade() -> None: ) op.drop_column("validators", "plugin_config") op.drop_column("validators", "plugin") - # ### end Alembic commands ### From 73b4de45147a506b72235fbc41b1537b697d4ad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 09:52:06 -0300 Subject: [PATCH 46/75] fix migration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- ...6b0dda_add_plugin_and_plugin_config_to_.py | 44 ++++++++++++++----- backend/node/create_nodes/providers.py | 1 + 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py b/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py index 6715e8bb0..ab20a8bde 100644 --- a/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py +++ b/backend/database_handler/migration/versions/986d9a6b0dda_add_plugin_and_plugin_config_to_.py @@ -9,11 +9,10 @@ from typing import Sequence, Union from alembic import op -from sqlalchemy.orm import sessionmaker +from sqlalchemy import column, table import sqlalchemy as sa from sqlalchemy.dialects import postgresql -from backend.database_handler.models import Validators from backend.node.create_nodes.providers import get_default_provider_for # revision identifiers, used by Alembic. @@ -30,16 +29,39 @@ def upgrade() -> None: sa.Column("plugin_config", postgresql.JSONB(astext_type=sa.Text())), ) - bind = op.get_bind() - # Create a new SQLAlchemy session using the connection - with sessionmaker(bind=bind)() as session: - validators = session.query(Validators).all() - for validator in validators: - default_provider = get_default_provider_for( - provider=validator.provider, model=validator.model + # Modify below + + # Create a table object for the validators table + validators = table( + "validators", + column("id", sa.Integer), + column("provider", sa.String), + column("model", sa.String), + column("plugin", sa.String), + column("plugin_config", postgresql.JSONB), + column("config", postgresql.JSONB), + ) + + # Fetch existing data + conn = op.get_bind() + results = conn.execute(validators.select()) + + # Process data and perform updates + for validator in results: + id = validator.id + provider = validator.provider + model = validator.model + default_provider = get_default_provider_for(provider=provider, model=model) + conn.execute( + validators.update() + .where(validators.c.id == id) + .values( + plugin=default_provider.plugin, + plugin_config=default_provider.plugin_config, + config=default_provider.config, ) - validator.plugin = default_provider.plugin - validator.plugin_config = default_provider.plugin_config + ) + # Modify above op.alter_column( "validators", "plugin", existing_type=sa.VARCHAR(length=255), nullable=False diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 0b9e42346..cda3feb7f 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -61,6 +61,7 @@ def get_default_provider_for(provider: str, model: str) -> LLMProvider: raise ValueError(f"No default provider found for {provider} and {model}") if len(matches) > 1: raise ValueError(f"Multiple default providers found for {provider} and {model}") + return matches[0] def _to_domain(provider: dict) -> LLMProvider: From 5746c2a3239e0ca9fcd154dfe6fa58bcf8058034 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 10:26:40 -0300 Subject: [PATCH 47/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../database_handler/validators_registry.py | 49 ++++++++--------- backend/protocol_rpc/endpoints.py | 54 +++++++++++-------- .../db-sqlalchemy/validators_registry_test.py | 27 ++++++++-- 3 files changed, 78 insertions(+), 52 deletions(-) diff --git a/backend/database_handler/validators_registry.py b/backend/database_handler/validators_registry.py index 0c9548551..1eeb45220 100644 --- a/backend/database_handler/validators_registry.py +++ b/backend/database_handler/validators_registry.py @@ -17,6 +17,8 @@ def to_dict(validator: Validators) -> dict: "provider": validator.provider, "model": validator.model, "config": validator.config, + "plugin": validator.plugin, + "plugin_config": validator.plugin_config, "created_at": validator.created_at.isoformat(), } @@ -51,7 +53,6 @@ def get_validator(self, validator_address: str) -> dict: def create_validator(self, validator: Validator) -> dict: self.session.add(_to_db_model(validator)) - self.session.commit() return self.get_validator(validator.address) @@ -62,41 +63,37 @@ def update_validator( validator = self._get_validator_or_fail(new_validator.address) validator.stake = new_validator.stake - validator.provider = new_validator.provider - validator.model = new_validator.model - validator.config = new_validator.config - - self.session.commit() + validator.provider = new_validator.llmprovider.provider + validator.model = new_validator.llmprovider.model + validator.config = new_validator.llmprovider.config + validator.plugin = new_validator.llmprovider.plugin + validator.plugin_config = new_validator.llmprovider.plugin_config return to_dict(validator) def delete_validator(self, validator_address): - self._get_validator_or_fail(validator_address) + validator = self._get_validator_or_fail(validator_address) - self.session.query(Validators).filter( - Validators.address == validator_address - ).delete() - self.session.commit() + self.session.delete(validator) def delete_all_validators(self): self.session.query(Validators).delete() - self.session.commit() -def _to_domain(validator: Validators) -> Validator: - return Validator( - address=validator.address, - stake=validator.stake, - llmprovider=LLMProvider( - provider=validator.provider, - model=validator.model, - config=validator.config, - plugin=validator.plugin, - plugin_config=validator.plugin_config, - id=None, - ), - id=validator.id, - ) +# def _to_domain(validator: Validators) -> Validator: +# return Validator( +# address=validator.address, +# stake=validator.stake, +# llmprovider=LLMProvider( +# provider=validator.provider, +# model=validator.model, +# config=validator.config, +# plugin=validator.plugin, +# plugin_config=validator.plugin_config, +# id=None, +# ), +# id=validator.id, +# ) def _to_db_model(validator: Validator) -> Validators: diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 88b14b0fb..883f46658 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -8,7 +8,7 @@ from backend.database_handler.db_client import DBClient from backend.database_handler.llm_providers import LLMProviderRegistry from backend.database_handler.models import Base -from backend.domain.types import LLMProvider +from backend.domain.types import LLMProvider, Validator from backend.node.create_nodes.providers import ( get_default_provider_for, get_default_providers, @@ -176,23 +176,27 @@ def create_validator( plugin_config: dict | None, ) -> dict: # fallback for default provider + # TODO: only accept all or none of the config fields + llm_provider = None if not (config and plugin and plugin_config): - provider = get_default_provider_for(provider, model) - config = provider.config + llm_provider = get_default_provider_for(provider, model) else: - validate_provider( - LLMProvider( - provider=provider, - model=model, - config=config, - plugin=plugin, - plugin_config=plugin_config, - ) + llm_provider = LLMProvider( + provider=provider, + model=model, + config=config, + plugin=plugin, + plugin_config=plugin_config, ) + validate_provider(llm_provider) new_address = accounts_manager.create_new_account().address return validators_registry.create_validator( - new_address, stake, provider, model, config + Validator( + address=new_address, + stake=stake, + llmprovider=llm_provider, + ) ) @@ -260,21 +264,29 @@ def update_validator( stake: int, provider: str, model: str, - config: json, + config: dict, + plugin: str, + plugin_config: dict, ) -> dict: # Remove validation while adding migration to update the db address # if not accounts_manager.is_valid_address(validator_address): # raise InvalidAddressError(validator_address) - validate_provider( - LLMProvider( - provider=provider, - model=model, - config=config, - ) + llm_provider = LLMProvider( + provider=provider, + model=model, + config=config, + plugin=plugin, + plugin_config=plugin_config, + id=None, ) - return validators_registry.update_validator( - validator_address, stake, provider, model, config + validate_provider(llm_provider) + validator = Validator( + address=validator_address, + stake=stake, + llmprovider=llm_provider, + id=None, ) + return validators_registry.update_validator(validator) def delete_validator( diff --git a/tests/db-sqlalchemy/validators_registry_test.py b/tests/db-sqlalchemy/validators_registry_test.py index 1da7af743..446ccadf5 100644 --- a/tests/db-sqlalchemy/validators_registry_test.py +++ b/tests/db-sqlalchemy/validators_registry_test.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from backend.database_handler.validators_registry import ValidatorsRegistry +from backend.domain.types import LLMProvider, Validator @pytest.fixture @@ -17,11 +18,24 @@ def test_validators_registry(validators_registry: ValidatorsRegistry): stake = 1 provider = "ollama" + plugin = "ollama" model = "llama3" config = {} - actual_validator = validators_registry.create_validator( - validator_address, stake, provider, model, config + plugin_config = {} + llm_provider = LLMProvider( + provider=provider, + model=model, + config=config, + plugin=plugin, + plugin_config=plugin_config, ) + validator = Validator( + address=validator_address, + stake=stake, + llmprovider=llm_provider, + ) + + actual_validator = validators_registry.create_validator(validator) assert validators_registry.count_validators() == 1 assert actual_validator["stake"] == stake @@ -43,9 +57,12 @@ def test_validators_registry(validators_registry: ValidatorsRegistry): new_model = "llama3.1" new_config = {"seed": 1, "key": {"array": [1, 2, 3]}} - actual_validator = validators_registry.update_validator( - validator_address, 2, new_provider, new_model, new_config - ) + validator.stake = new_stake + validator.llmprovider.provider = new_provider + validator.llmprovider.model = new_model + validator.llmprovider.config = new_config + + actual_validator = validators_registry.update_validator(validator) assert validators_registry.count_validators() == 1 From c9b7a6e06314be7dc00ca536406e496a64d0df24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 10:30:35 -0300 Subject: [PATCH 48/75] ci: add db integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../backend_integration_tests_pr.yml | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index cfd09ed1f..76ed27186 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -108,3 +108,30 @@ jobs: - name: Shutdown Docker Compose if: always() run: docker compose down + + db-integration-test: + needs: triggers + if: ${{ needs.triggers.outputs.is_pull_request_opened == 'true' || needs.triggers.outputs.is_pull_request_review_approved == 'true' || needs.triggers.outputs.is_pull_request_labeled_with_run_tests == 'true' }} + + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Cache Docker layers + uses: actions/cache@v4 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-buildx-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-buildx- + + - name: Build Docker image + run: docker compose build + + - name: Run Docker Compose + run: docker compose -f tests/db-sqlalchemy/docker-compose.yml --project-directory . up tests --build --force-recreate --always-recreate-deps From c5209ac77bfc0c090ce8640fb7c7e4f5022c675d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 10:45:21 -0300 Subject: [PATCH 49/75] fix unit test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- tests/unit/test_create_nodes.py | 192 ++++++++++++++++++++++++++------ 1 file changed, 160 insertions(+), 32 deletions(-) diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index eec0c33d5..7394e8c08 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -10,7 +10,11 @@ ["llama3"], [ LLMProvider( - provider="ollama", model="llama3", config={}, plugin_config={} + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, ) ], None, @@ -19,7 +23,11 @@ {}, [ LLMProvider( - provider="ollama", model="llama3", config={}, plugin_config={} + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, ) ], id="only ollama", @@ -28,16 +36,32 @@ ["llama3", "llama3.1"], [ LLMProvider( - provider="ollama", model="llama3.1", config={}, plugin_config={} + provider="ollama", + model="llama3.1", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4", config={}, plugin_config={} + provider="openai", + model="gpt-4", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4o", config={}, plugin_config={} + provider="openai", + model="gpt-4o", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="", config={}, plugin_config={} + provider="heuristai", + model="", + config={}, + plugin="", + plugin_config={}, ), ], None, @@ -46,7 +70,11 @@ {"OPENAI_API_KEY": ""}, [ LLMProvider( - provider="ollama", model="llama3.1", config={}, plugin_config={} + provider="ollama", + model="llama3.1", + config={}, + plugin="", + plugin_config={}, ) ], id="only ollama available", @@ -55,22 +83,46 @@ ["llama3", "llama3.1"], [ LLMProvider( - provider="ollama", model="llama3", config={}, plugin_config={} + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4", config={}, plugin_config={} + provider="openai", + model="gpt-4", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4o", config={}, plugin_config={} + provider="openai", + model="gpt-4o", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="", config={}, plugin_config={} + provider="heuristai", + model="", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="a", config={}, plugin_config={} + provider="heuristai", + model="a", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="b", config={}, plugin_config={} + provider="heuristai", + model="b", + config={}, + plugin="", + plugin_config={}, ), ], ["openai"], @@ -79,10 +131,18 @@ {"OPENAIKEY": "filled"}, [ LLMProvider( - provider="openai", model="gpt-4", config={}, plugin_config={} + provider="openai", + model="gpt-4", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4o", config={}, plugin_config={} + provider="openai", + model="gpt-4o", + config={}, + plugin="", + plugin_config={}, ), ], id="only openai", @@ -91,16 +151,32 @@ ["llama3", "llama3.1"], [ LLMProvider( - provider="openai", model="gpt-4", config={}, plugin_config={} + provider="openai", + model="gpt-4", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4o", config={}, plugin_config={} + provider="openai", + model="gpt-4o", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="a", config={}, plugin_config={} + provider="heuristai", + model="a", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="b", config={}, plugin_config={} + provider="heuristai", + model="b", + config={}, + plugin="", + plugin_config={}, ), ], ["heuristai"], @@ -109,7 +185,11 @@ {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, [ LLMProvider( - provider="heuristai", model="a", config={}, plugin_config={} + provider="heuristai", + model="a", + config={}, + plugin="", + plugin_config={}, ), ], id="only heuristai", @@ -118,19 +198,39 @@ ["llama3", "llama3.1"], [ LLMProvider( - provider="ollama", model="llama3.1", config={}, plugin_config={} + provider="ollama", + model="llama3.1", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4", config={}, plugin_config={} + provider="openai", + model="gpt-4", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4o", config={}, plugin_config={} + provider="openai", + model="gpt-4o", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="a", config={}, plugin_config={} + provider="heuristai", + model="a", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="b", config={}, plugin_config={} + provider="heuristai", + model="b", + config={}, + plugin="", + plugin_config={}, ), ], None, @@ -139,19 +239,39 @@ {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, [ LLMProvider( - provider="ollama", model="llama3.1", config={}, plugin_config={} + provider="ollama", + model="llama3.1", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4", config={}, plugin_config={} + provider="openai", + model="gpt-4", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="openai", model="gpt-4o", config={}, plugin_config={} + provider="openai", + model="gpt-4o", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="a", config={}, plugin_config={} + provider="heuristai", + model="a", + config={}, + plugin="", + plugin_config={}, ), LLMProvider( - provider="heuristai", model="b", config={}, plugin_config={} + provider="heuristai", + model="b", + config={}, + plugin="", + plugin_config={}, ), ], id="all available", @@ -186,7 +306,11 @@ def test_random_validator_config( [], [ LLMProvider( - provider="ollama", model="llama3", config={}, plugin_config={} + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, ) ], ["heuristai", "openai"], @@ -200,7 +324,11 @@ def test_random_validator_config( [], [ LLMProvider( - provider="ollama", model="llama3", config={}, plugin_config={} + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, ) ], ["ollama"], From 53c3361f9c4ea682e1144fce79e88278366cd09c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 10:54:05 -0300 Subject: [PATCH 50/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 7 ++++++- backend/protocol_rpc/endpoints.py | 4 ++-- tests/integration/test_llm_providers_registry.py | 4 +++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index cda3feb7f..46174d31b 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -20,9 +20,14 @@ def get_schema() -> dict: def validate_provider(provider: LLMProvider): + # Convert to JSON + provider_dict = provider.__dict__ + del provider_dict["id"] + + # Check against schema schema = get_schema() try: - validate(instance=provider.__dict__, schema=schema) + validate(instance=provider_dict, schema=schema) except Exception as e: raise ValueError(f"Error validating provider: {e}") diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 883f46658..efe747da6 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -139,7 +139,7 @@ def add_provider(llm_provider_registry: LLMProviderRegistry, params: dict) -> di model=params["model"], config=params["config"], plugin=params["plugin"], - plugin_config=["plugin_config"], + plugin_config=params["plugin_config"], ) validate_provider(provider) @@ -154,7 +154,7 @@ def update_provider( model=params["model"], config=params["config"], plugin=params["plugin"], - plugin_config=["plugin_config"], + plugin_config=params["plugin_config"], ) validate_provider(provider) diff --git a/tests/integration/test_llm_providers_registry.py b/tests/integration/test_llm_providers_registry.py index aac023374..7d6ff26db 100644 --- a/tests/integration/test_llm_providers_registry.py +++ b/tests/integration/test_llm_providers_registry.py @@ -5,8 +5,9 @@ def test_llm_providers(): provider = { "provider": "openai", - "model": "gpt-3.5-turbo", + "model": "gpt-4", "config": "", + "plugin": "openai", "plugin_config": {"api_key_env_var": "OPENAIKEY"}, } # Create a new provider @@ -19,6 +20,7 @@ def test_llm_providers(): "provider": "openai", "model": "gpt-4o", "config": "", + "plugin": "openai", "plugin_config": {"api_key_env_var": "OPENAIKEY"}, } # Uodate it From 96f19837aa630eb354648d4e554c730a8c7951c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 11:39:08 -0300 Subject: [PATCH 51/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/database_handler/validators_registry.py | 1 - backend/protocol_rpc/endpoints.py | 6 +----- backend/protocol_rpc/server.py | 13 +++++++++++++ tests/db-sqlalchemy/validators_registry_test.py | 3 +++ 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/backend/database_handler/validators_registry.py b/backend/database_handler/validators_registry.py index 1eeb45220..e54e70cff 100644 --- a/backend/database_handler/validators_registry.py +++ b/backend/database_handler/validators_registry.py @@ -53,7 +53,6 @@ def get_validator(self, validator_address: str) -> dict: def create_validator(self, validator: Validator) -> dict: self.session.add(_to_db_model(validator)) - return self.get_validator(validator.address) def update_validator( diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index efe747da6..184006e9b 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -246,11 +246,7 @@ def create_random_validators( validator_address = accounts_manager.create_new_account().address validator = validators_registry.create_validator( - validator_address, - stake, - detail.provider, - detail.model, - detail.config, + Validator(address=validator_address, stake=stake, llmprovider=detail) ) response.append(validator) diff --git a/backend/protocol_rpc/server.py b/backend/protocol_rpc/server.py index 01312c2d9..f76c80737 100644 --- a/backend/protocol_rpc/server.py +++ b/backend/protocol_rpc/server.py @@ -63,6 +63,7 @@ def create_app(): validators_registry, consensus, llm_provider_registry, + sqlalchemy_db, ) @@ -78,6 +79,7 @@ def create_app(): validators_registry, consensus, llm_provider_registry, + sqlalchemy_db, ) = create_app() register_all_rpc_endpoints( jsonrpc, @@ -91,6 +93,17 @@ def create_app(): ) +# This method ensures that the transaction is committed or rolled back depending on the success of the request. +# Opinions on whether this is a good practice are divided https://github.com/pallets-eco/flask-sqlalchemy/issues/216 +@app.teardown_appcontext +def shutdown_session(exception=None): + if exception: + sqlalchemy_db.session.rollback() # Rollback if there is an exception + else: + sqlalchemy_db.session.commit() # Commit if everything is fine + sqlalchemy_db.session.remove() # Remove the session after every request + + def run_socketio(): socketio.run( app, diff --git a/tests/db-sqlalchemy/validators_registry_test.py b/tests/db-sqlalchemy/validators_registry_test.py index 446ccadf5..a5037fe23 100644 --- a/tests/db-sqlalchemy/validators_registry_test.py +++ b/tests/db-sqlalchemy/validators_registry_test.py @@ -35,6 +35,7 @@ def test_validators_registry(validators_registry: ValidatorsRegistry): llmprovider=llm_provider, ) + # Create actual_validator = validators_registry.create_validator(validator) assert validators_registry.count_validators() == 1 @@ -52,6 +53,7 @@ def test_validators_registry(validators_registry: ValidatorsRegistry): assert actual_validators == [actual_validator] + # Update new_stake = 2 new_provider = "ollama_new" new_model = "llama3.1" @@ -75,6 +77,7 @@ def test_validators_registry(validators_registry: ValidatorsRegistry): assert actual_validator["id"] == validator_id assert actual_validator["created_at"] == created_at + # Delete validators_registry.delete_validator(validator_address) assert len(validators_registry.get_all_validators()) == 0 From 924ba5e82dfa0b5fa11cb1a5838aac33a4ddc2dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 14:46:31 -0300 Subject: [PATCH 52/75] fix node MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../backend_integration_tests_pr.yml | 3 -- backend/consensus/base.py | 35 +++++++++----- backend/consensus/vrf.py | 4 +- .../database_handler/validators_registry.py | 3 +- backend/domain/types.py | 16 +++++++ backend/node/base.py | 6 ++- backend/node/create_nodes/providers.py | 1 + backend/node/genvm/base.py | 2 +- backend/node/genvm/equivalence_principle.py | 4 +- backend/node/genvm/std/vector_store.py | 15 +++--- backend/protocol_rpc/endpoints.py | 48 ++++++++++++------- .../mocks/call_contract_function.py | 4 +- tests/unit/test_types.py | 28 +++++++++++ 13 files changed, 123 insertions(+), 46 deletions(-) create mode 100644 tests/unit/test_types.py diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 76ed27186..d96d7a317 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -130,8 +130,5 @@ jobs: restore-keys: | ${{ runner.os }}-buildx- - - name: Build Docker image - run: docker compose build - - name: Run Docker Compose run: docker compose -f tests/db-sqlalchemy/docker-compose.yml --project-directory . up tests --build --force-recreate --always-recreate-deps diff --git a/backend/consensus/base.py b/backend/consensus/base.py index ed4d6858e..456fd151f 100644 --- a/backend/consensus/base.py +++ b/backend/consensus/base.py @@ -17,6 +17,7 @@ ) from backend.database_handler.accounts_manager import AccountsManager from backend.database_handler.types import ConsensusData +from backend.domain.types import LLMProvider, Validator from backend.node.base import Node from backend.node.genvm.types import ExecutionMode, Vote from backend.protocol_rpc.message_handler.base import MessageHandler @@ -128,12 +129,18 @@ async def exec_transaction( # Create Leader leader_node = Node( contract_snapshot=contract_snapshot, - address=leader["address"], validator_mode=ExecutionMode.LEADER, - stake=leader["stake"], - provider=leader["provider"], - model=leader["model"], - config=leader["config"], + validator=Validator( + address=leader["address"], + stake=leader["stake"], + llmprovider=LLMProvider( + provider=leader["provider"], + model=leader["model"], + config=leader["config"], + plugin=leader["plugin"], + plugin_config=leader["plugin_config"], + ), + ), msg_handler=self.msg_handler, ) @@ -149,16 +156,22 @@ async def exec_transaction( validator_nodes = [ Node( contract_snapshot=contract_snapshot, - address=validator["address"], validator_mode=ExecutionMode.VALIDATOR, - stake=validator["stake"], - provider=validator["provider"], - model=validator["model"], - config=validator["config"], + validator=Validator( + address=validator["address"], + stake=validator["stake"], + llmprovider=LLMProvider( + provider=validator["provider"], + model=validator["model"], + config=validator["config"], + plugin=validator["plugin"], + plugin_config=validator["plugin_config"], + ), + ), leader_receipt=leader_receipt, msg_handler=self.msg_handler, ) - for i, validator in enumerate(remaining_validators) + for validator in remaining_validators ] # Validators execute transaction diff --git a/backend/consensus/vrf.py b/backend/consensus/vrf.py index 38d99cda1..228ff8c2c 100644 --- a/backend/consensus/vrf.py +++ b/backend/consensus/vrf.py @@ -20,7 +20,9 @@ def select_random_validators(all_validators: list, num_validators: int) -> list: return [all_validators[i] for i in unique_indices] -def get_validators_for_transaction(all_validators: list, num_validators: int) -> tuple: +def get_validators_for_transaction( + all_validators: list, num_validators: int +) -> tuple[dict, list]: selected_validators = select_random_validators(all_validators, num_validators) leader = selected_validators[0] remaining_validators = selected_validators[1 : num_validators + 1] diff --git a/backend/database_handler/validators_registry.py b/backend/database_handler/validators_registry.py index e54e70cff..d1c687d1b 100644 --- a/backend/database_handler/validators_registry.py +++ b/backend/database_handler/validators_registry.py @@ -1,5 +1,6 @@ # consensus/domain/state.py +from typing import List from sqlalchemy.orm import Session from backend.domain.types import LLMProvider, Validator @@ -44,7 +45,7 @@ def _get_validator_or_fail(self, validator_address: str) -> Validators: def count_validators(self) -> int: return self.session.query(Validators).count() - def get_all_validators(self) -> list: + def get_all_validators(self) -> List[dict]: validators_data = self.session.query(Validators).all() return [to_dict(validator) for validator in validators_data] diff --git a/backend/domain/types.py b/backend/domain/types.py index c7ae0899e..4533962f3 100644 --- a/backend/domain/types.py +++ b/backend/domain/types.py @@ -32,3 +32,19 @@ class Validator: stake: int llmprovider: LLMProvider id: int | None = None + + def to_dict(self): + result = { + "address": self.address, + "stake": self.stake, + "provider": self.llmprovider.provider, + "model": self.llmprovider.model, + "config": self.llmprovider.config, + "plugin": self.llmprovider.plugin, + "plugin_config": self.llmprovider.plugin_config, + } + + if self.id: + result["id"] = self.id + + return result diff --git a/backend/node/base.py b/backend/node/base.py index 254d936e3..4879ed99b 100644 --- a/backend/node/base.py +++ b/backend/node/base.py @@ -1,3 +1,4 @@ +from dataclasses import asdict import json from typing import Optional @@ -21,7 +22,10 @@ def __init__( self.address = validator.address self.leader_receipt = leader_receipt self.genvm = GenVM( - contract_snapshot, self.validator_mode, validator.__dict__, msg_handler + contract_snapshot, + self.validator_mode, + validator.to_dict(), + msg_handler, ) async def exec_transaction(self, transaction: dict): diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 46174d31b..28018116b 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -80,6 +80,7 @@ def _to_domain(provider: dict) -> LLMProvider: ) +# TODO: We could merge part of this logic of getting the available providers by loading the plugins. The plugins could have methods like `is_available` and `get_available_models` that would simplify this logic. def create_random_providers(amount: int) -> list[LLMProvider]: """ Not being used at the moment, left here for future reference. diff --git a/backend/node/genvm/base.py b/backend/node/genvm/base.py index 72ac2d53a..851b3758b 100644 --- a/backend/node/genvm/base.py +++ b/backend/node/genvm/base.py @@ -86,7 +86,7 @@ def _generate_receipt( gas_used=self.contract_runner.gas_used, mode=self.contract_runner.mode, contract_state=encoded_object, - node_config=self.contract_runner.validator.__dict__, + node_config=self.contract_runner.node_config, eq_outputs=self.contract_runner.eq_outputs, execution_result=execution_result, error=error, diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py index f2286d740..91806b47d 100644 --- a/backend/node/genvm/equivalence_principle.py +++ b/backend/node/genvm/equivalence_principle.py @@ -1,6 +1,6 @@ # backend/node/genvm/equivalence_principle.py -from typing import Optional +from typing import Any, Optional from backend.node.genvm.context_wrapper import enforce_with_context from backend.node.genvm import llms from backend.node.genvm.webpage_utils import get_webpage_content @@ -22,7 +22,7 @@ def clear_locals(scope): @enforce_with_context class EquivalencePrinciple: - contract_runner: any # TODO: this should be of type ContractRunner but that raises a cyclic import error + contract_runner: Any # TODO: this should be of type ContractRunner but that raises a cyclic import error def __init__( self, diff --git a/backend/node/genvm/std/vector_store.py b/backend/node/genvm/std/vector_store.py index 1c99829b4..438e0562a 100644 --- a/backend/node/genvm/std/vector_store.py +++ b/backend/node/genvm/std/vector_store.py @@ -1,5 +1,6 @@ # backend/node/genvm/std/vector_store.py +from typing import Any import numpy as np from backend.node.genvm.std.models import get_model @@ -18,13 +19,13 @@ def __init__(self, model_name: str = None): self.model_name = model_name self.next_id = 0 - def add_text(self, text: str, metadata: any): + def add_text(self, text: str, metadata: Any): """ Add a new text to the store with its metadata. Args: text (str): The text to be added. - metadata (any): The metadata. + metadata (Any): The metadata. """ model = get_model(self.model_name) @@ -38,7 +39,7 @@ def add_text(self, text: str, metadata: any): return vector_id - def get_closest_vector(self, text: str) -> tuple[float, int, str, any, list[float]]: + def get_closest_vector(self, text: str) -> tuple[float, int, str, Any, list[float]]: """ Get the closest vector to the given text along with the similarity percentage and metadata. @@ -58,7 +59,7 @@ def get_closest_vector(self, text: str) -> tuple[float, int, str, any, list[floa def get_k_closest_vectors( self, text: str, k: int = 5 - ) -> list[tuple[float, int, str, any, list[float]]]: + ) -> list[tuple[float, int, str, Any, list[float]]]: """ Get the closest k vectors to the given text along with the similarity percentages and metadata. @@ -103,7 +104,7 @@ def get_k_closest_vectors( ] return results - def update_text(self, vector_id: int, new_text: str, new_metadata: any): + def update_text(self, vector_id: int, new_text: str, new_metadata: Any): """ Update the text and metadata of an existing vector. @@ -135,7 +136,7 @@ def delete_vector(self, vector_id: int): else: raise ValueError("Vector ID does not exist") - def get_vector(self, vector_id: int) -> tuple[str, any, list[float]]: + def get_vector(self, vector_id: int) -> tuple[str, Any, list[float]]: """ Retrieve a vector and its metadata from the store. @@ -172,7 +173,7 @@ def cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float: norm_b = np.linalg.norm(b) return dot_product / (norm_a * norm_b) - def get_all_items(self) -> list[tuple[str, any]]: + def get_all_items(self) -> list[tuple[str, Any]]: """ Get all vectors and their metadata from the store. """ diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 184006e9b..b528a67f6 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -93,14 +93,14 @@ def get_contract_schema( ) contract_account = accounts_manager.get_account_or_fail(contract_address) - node = Node( + node = Node( # Mock node just to get the data from the GenVM contract_snapshot=None, - address="", validator_mode=ExecutionMode.LEADER, - stake=0, - provider="", - model="", - config=None, + validator=Validator( + address="", + stake=0, + llmprovider=None, + ), leader_receipt=None, msg_handler=msg_handler, ) @@ -110,14 +110,20 @@ def get_contract_schema( def get_contract_schema_for_code( msg_handler: MessageHandler, contract_code: str ) -> dict: - node = Node( + node = Node( # Mock node just to get the data from the GenVM contract_snapshot=None, - address="", validator_mode=ExecutionMode.LEADER, - stake=0, - provider="", - model="", - config=None, + validator=Validator( + address="", + stake=0, + llmprovider=LLMProvider( + provider="", + model="", + config={}, + plugin="", + plugin_config={}, + ), + ), leader_receipt=None, msg_handler=msg_handler, ) @@ -346,14 +352,20 @@ def call( decoded_data = decode_method_call_data(input) contract_account = accounts_manager.get_account_or_fail(to_address) - node = Node( + node = Node( # Mock node just to get the data from the GenVM contract_snapshot=None, - address="", validator_mode=ExecutionMode.LEADER, - stake=0, - provider="", - model="", - config=None, + validator=Validator( + address="", + stake=0, + llmprovider=LLMProvider( + provider="", + model="", + config={}, + plugin="", + plugin_config={}, + ), + ), leader_receipt=None, msg_handler=msg_handler, ) diff --git a/tests/integration/contract_examples/mocks/call_contract_function.py b/tests/integration/contract_examples/mocks/call_contract_function.py index 8d264ea4d..b546eb17d 100644 --- a/tests/integration/contract_examples/mocks/call_contract_function.py +++ b/tests/integration/contract_examples/mocks/call_contract_function.py @@ -16,10 +16,12 @@ "method": str, "mode": str, "node_config": { - "config": dict, + "config": dict | str, "model": str, "provider": str, "stake": int, + "plugin": str, + "plugin_config": dict, }, "vote": str, }, diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py new file mode 100644 index 000000000..2f87cad13 --- /dev/null +++ b/tests/unit/test_types.py @@ -0,0 +1,28 @@ +from dataclasses import asdict +from backend.domain.types import LLMProvider, Validator + + +def test_validator_to_dict(): + validator = Validator( + address="0x1234", + stake=100, + llmprovider=LLMProvider( + provider="provider", + model="model", + config={"config": "config"}, + plugin="plugin", + plugin_config={"plugin_config": "plugin_config"}, + ), + ) + + result = validator.to_dict() + + assert result == { + "address": "0x1234", + "stake": 100, + "provider": "provider", + "model": "model", + "config": {"config": "config"}, + "plugin": "plugin", + "plugin_config": {"plugin_config": "plugin_config"}, + } From c9d757bec76d318f5e3bb330c705529565dde979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 15:07:12 -0300 Subject: [PATCH 53/75] fixt tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .github/workflows/backend_integration_tests_pr.yml | 6 ++++-- docker-compose.yml | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index d96d7a317..9cb77601d 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -60,8 +60,10 @@ jobs: restore-keys: | ${{ runner.os }}-buildx- - - name: Build Docker image - run: docker compose build + - name: Build Docker images + run: | + docker compose build build-backend-base + docker compose build - name: Run Docker Compose run: docker compose up -d diff --git a/docker-compose.yml b/docker-compose.yml index 6b9aea618..934186d43 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -79,6 +79,7 @@ services: #volumes: # - "./data/postgres:/var/lib/postgresql/data" + # Used for caching build-backend-base: build: context: . From 1104d5ff24e96c8b0a0f66d492f3daee124faca1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 15:41:02 -0300 Subject: [PATCH 54/75] fix test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index 934186d43..debe2094c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -93,7 +93,7 @@ services: context: . dockerfile: docker/Dockerfile.database-migration args: - BASE_IMAGE: backend-base + BASE_IMAGE: backend-base:latest environment: - DB_URL=postgresql://${DBUSER}:${DBUSER}@postgres/${DBNAME} depends_on: From cf5586949bf8768ff900059a882d29df0507c6e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 16:20:50 -0300 Subject: [PATCH 55/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- docker-compose.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker-compose.yml b/docker-compose.yml index debe2094c..e6f53ff63 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -94,6 +94,7 @@ services: dockerfile: docker/Dockerfile.database-migration args: BASE_IMAGE: backend-base:latest + pull_policy: never environment: - DB_URL=postgresql://${DBUSER}:${DBUSER}@postgres/${DBNAME} depends_on: From f04b1b2267529e3378b9b7f46a226a74826cad68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 16:27:45 -0300 Subject: [PATCH 56/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yml b/docker-compose.yml index e6f53ff63..9ce610f3b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -85,7 +85,7 @@ services: context: . dockerfile: docker/Dockerfile.backend target: base - image: backend-base + image: backend-base:latest pull_policy: build database-migration: From 610c081a8ccd79af4d1f77a12b575c9e2964efc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 16:53:45 -0300 Subject: [PATCH 57/75] improve plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/create_nodes.py | 54 ++------ .../heuristai_dolphin-2.9-llama3-8b.json | 2 +- .../heuristai_meta-llamallama-2-70b-chat.json | 2 +- ...ristai_mistralaimixtral-8x7b-instruct.json | 2 +- .../heuristai_openhermes-2-yi-34b-gptq.json | 2 +- .../node/create_nodes/providers_schema.json | 4 +- backend/node/genvm/equivalence_principle.py | 5 +- backend/node/genvm/llms.py | 130 ++++++++++++++---- backend/protocol_rpc/configuration.py | 14 +- backend/protocol_rpc/endpoints.py | 1 - backend/protocol_rpc/server.py | 2 +- 11 files changed, 124 insertions(+), 94 deletions(-) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index 91c97efe2..b75d1d201 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -1,37 +1,21 @@ -import os -import re import secrets from typing import Callable, List -import requests from numpy.random import default_rng from dotenv import load_dotenv from backend.domain.types import LLMProvider +from backend.node.genvm.llms import get_llm_plugin load_dotenv() rng = default_rng(secrets.randbits(128)) -empty_provider_key_regex = r"^(|)$" -provider_key_names_suffix = ["_API_KEY", "KEY", "APIKEY"] - - -def get_available_ollama_models(get_ollama_url: Callable[[str], str]) -> List[str]: - ollama_models_result = requests.get(get_ollama_url("tags")).json() - installed_ollama_models = [] - for ollama_model in ollama_models_result["models"]: - # "llama3:latest" => "llama3" - installed_ollama_models.append(ollama_model["name"].split(":")[0]) - return installed_ollama_models - def random_validator_config( - get_available_ollama_models: Callable[[], str], get_stored_providers: Callable[[], List[LLMProvider]], limit_providers: set[str] = None, limit_models: set[str] = None, amount: int = 1, - environ: dict[str, str] = os.environ, ) -> List[LLMProvider]: providers_to_use = get_stored_providers() @@ -52,38 +36,20 @@ def random_validator_config( f"Requested providers '{limit_providers}' do not match any stored providers. Please review your stored providers." ) - # Ollama is the only provider which is not OpenAI compliant, thus it gets its custom logic - # To add more non-OpenAI compliant providers, we'll need to add more custom logic here or refactor the provider's schema to allow general configurations - available_ollama_models = get_available_ollama_models() + def filter_by_available(provider: LLMProvider) -> bool: + # TODO: we should probably inject the `get_llm_plugin` function to be able to mock it in tests + plugin = get_llm_plugin(provider.plugin, provider.plugin_config) + if not plugin.is_available(): + return False - providers_to_use = [ - provider - for provider in providers_to_use - if provider.provider != "ollama" or provider.model in available_ollama_models - ] + if not plugin.is_model_available(provider.model): + return False - def filter_by_available_key(provider: LLMProvider) -> bool: - if provider.provider == "ollama": - return True - provider_key_names = [ - provider.provider.upper() + suffix for suffix in provider_key_names_suffix - ] - for provider_key_name in provider_key_names: - if not re.match( - empty_provider_key_regex, environ.get(provider_key_name, "") - ): - return True - - return False + return True - providers_to_use = list(filter(filter_by_available_key, providers_to_use)) + providers_to_use = list(filter(filter_by_available, providers_to_use)) if not providers_to_use: raise Exception("No providers avaliable.") - # heuristic_models_result = requests.get(os.environ['HEURISTAIMODELSURL']).json() - # heuristic_models = [] - # for entry in heuristic_models_result: - # heuristic_models.append(entry['name']) - return list(rng.choice(providers_to_use, amount)) diff --git a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json index 191222f40..b18bead58 100644 --- a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json +++ b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json @@ -8,6 +8,6 @@ }, "plugin_config": { "api_key_env_var": "HEURISTAIAPIKEY", - "url": "https://llm-gateway.heurist.xyz" + "api_url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json index e8e16a1b8..15a514b26 100644 --- a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json +++ b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json @@ -8,6 +8,6 @@ }, "plugin_config": { "api_key_env_var": "HEURISTAIAPIKEY", - "url": "https://llm-gateway.heurist.xyz" + "api_url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json index eaeb6f35e..edfa965ea 100644 --- a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json +++ b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json @@ -8,6 +8,6 @@ }, "plugin_config": { "api_key_env_var": "HEURISTAIAPIKEY", - "url": "https://llm-gateway.heurist.xyz" + "api_url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json index 86b5a5609..a4c78975f 100644 --- a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json +++ b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json @@ -8,6 +8,6 @@ }, "plugin_config": { "api_key_env_var": "HEURISTAIAPIKEY", - "url": "https://llm-gateway.heurist.xyz" + "api_url": "https://llm-gateway.heurist.xyz" } } diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 0411bd37d..52df77790 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -195,12 +195,12 @@ "type": "string", "$comment": "Environment variable that contains the API key" }, - "url": { + "api_url": { "type": "string", "$comment": "URL of the API endpoint" } }, - "required": ["api_key_env_var", "url"] + "required": ["api_key_env_var", "api_url"] }, "model": { "enum": [ diff --git a/backend/node/genvm/equivalence_principle.py b/backend/node/genvm/equivalence_principle.py index 91806b47d..898fa133c 100644 --- a/backend/node/genvm/equivalence_principle.py +++ b/backend/node/genvm/equivalence_principle.py @@ -92,7 +92,10 @@ def set(self, value): self.contract_runner.eq_num += 1 def __get_llm_function(self): - return llms.get_llm_function(self.contract_runner.node_config["plugin"]).call + return llms.get_llm_plugin( + self.contract_runner.node_config["plugin"], + self.contract_runner.node_config["plugin_config"], + ).call async def call_llm_with_principle(prompt, eq_principle, comparative=True): diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index 0f8bfa65b..7ca265988 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -1,7 +1,7 @@ """ This file contains the plugins (functions) that are used to interact with the different LLMs (Language Model Models) that are used in the system. The plugins are registered in the `get_llm_function` function, which returns the function that corresponds to the plugin name. The plugins are called with the following parameters: -- `model_config`: A dictionary containing the model and configuration to be used. +- `node_config`: A dictionary containing the model and configuration to be used. - `prompt`: The prompt to be sent to the LLM. - `regex`: A regular expression to be used to stop the LLM. - `return_streaming_channel`: An optional asyncio.Queue to stream the response. @@ -18,6 +18,7 @@ from openai.types.chat import ChatCompletionChunk from dotenv import load_dotenv +import requests load_dotenv() @@ -44,16 +45,16 @@ async def stream_http_response(url, data): async def call_ollama( - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: url = get_ollama_url("generate") - data = {"model": model_config["model"], "prompt": prompt} + data = {"model": node_config["model"], "prompt": prompt} - for name, value in model_config["config"].items(): + for name, value in node_config["config"].items(): data[name] = value buffer = "" @@ -75,29 +76,29 @@ async def call_ollama( async def call_openai( - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - api_key_env_var = model_config[plugin_config_key]["api_key_env_var"] + api_key_env_var = node_config[plugin_config_key]["api_key_env_var"] client = get_openai_client(os.environ.get(api_key_env_var)) # TODO: OpenAI exceptions need to be caught here - stream = get_openai_stream(client, prompt, model_config) + stream = get_openai_stream(client, prompt, node_config) return await get_openai_output(stream, regex, return_streaming_channel) async def call_heuristai( - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - api_key_env_var = model_config[plugin_config_key]["api_key_env_var"] - url = model_config[plugin_config_key]["url"] + api_key_env_var = node_config[plugin_config_key]["api_key_env_var"] + url = node_config[plugin_config_key]["api_url"] client = get_openai_client(os.environ.get(api_key_env_var), os.environ.get(url)) - stream = get_openai_stream(client, prompt, model_config) + stream = get_openai_stream(client, prompt, node_config) # TODO: Get the line below working # return await get_openai_output(stream, regex, return_streaming_channel) output = "" @@ -120,11 +121,11 @@ def get_openai_client(api_key: str, url: str = None) -> OpenAI: return openai_client -def get_openai_stream(client: OpenAI, prompt, model_config): - config = model_config["config"] +def get_openai_stream(client: OpenAI, prompt, node_config): + config = node_config["config"] if "temperature" in config and "max_tokens" in config: return client.chat.completions.create( - model=model_config["model"], + model=node_config["model"], messages=[{"role": "user", "content": prompt}], stream=True, temperature=config["temperature"], @@ -132,7 +133,7 @@ def get_openai_stream(client: OpenAI, prompt, model_config): ) else: return client.chat.completions.create( - model=model_config["model"], + model=node_config["model"], messages=[{"role": "user", "content": prompt}], stream=True, ) @@ -167,62 +168,135 @@ def get_ollama_url(endpoint: str) -> str: class Plugin(ABC): @abstractmethod - def call( + def __init__(self, plugin_config: dict): + pass + + @abstractmethod + async def call( self, - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: pass + @abstractmethod + def is_available(self) -> bool: + pass + + @abstractmethod + def is_model_available(self, model: str) -> bool: + pass + class OllamaPlugin(Plugin): + def __init__(self, plugin_config: dict): + self.url = plugin_config["api_url"] + async def call( self, - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - return await call_ollama(model_config, prompt, regex, return_streaming_channel) + return await call_ollama(node_config, prompt, regex, return_streaming_channel) + + def is_available(self) -> bool: + try: + if requests.get(self.url).status_code == 404: + return True + except Exception: + pass + return False + + def is_model_available(self, model: str) -> bool: + endpoint = f"{self.url}/tags" + ollama_models_result = requests.get(endpoint).json() + installed_ollama_models = [] + for ollama_model in ollama_models_result["models"]: + installed_ollama_models.append(ollama_model["name"].split(":")[0]) + return model in installed_ollama_models class OpenAIPlugin(Plugin): + def __init__(self, plugin_config: dict): + self.api_key_env_var = plugin_config["api_key_env_var"] + async def call( self, - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - return await call_openai(model_config, prompt, regex, return_streaming_channel) + return await call_openai(node_config, prompt, regex, return_streaming_channel) + + def is_available(self) -> bool: + env_var = os.environ.get(self.api_key_env_var) + + return ( + env_var != None + and env_var != "" + and env_var != "" + ) + + def is_model_available(self, model: str) -> bool: + """ + Model checks are done by the shema providers_schema.json + """ + return True class HeuristAIPlugin(Plugin): + def __init__(self, plugin_config: dict): + self.api_key_env_var = plugin_config["api_key_env_var"] + self.url = plugin_config["api_url"] + async def call( self, - model_config: dict, + node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: return await call_heuristai( - model_config, prompt, regex, return_streaming_channel + node_config, prompt, regex, return_streaming_channel + ) + + def is_available(self) -> bool: + env_var = os.environ.get(self.api_key_env_var) + + return ( + env_var != None + and env_var != "" + and env_var != "" ) + def is_model_available(self, model: str) -> bool: + """ + Model checks are done by the shema providers_schema.json + """ + # heuristic_models_result = requests.get(os.environ['HEURISTAIMODELSURL']).json() + # heuristic_models = [] + # for entry in heuristic_models_result: + # heuristic_models.append(entry['name']) + + return True + -def get_llm_function(plugin: str) -> Plugin: +def get_llm_plugin(plugin: str, plugin_config: dict) -> Plugin: """ Function to register new providers """ plugin_map = { - "ollama": OllamaPlugin(), - "openai": OpenAIPlugin(), - "heuristai": HeuristAIPlugin(), + "ollama": OllamaPlugin, + "openai": OpenAIPlugin, + "heuristai": HeuristAIPlugin, } if plugin not in plugin_map: raise ValueError(f"Plugin {plugin} not registered.") - return plugin_map[plugin] + return plugin_map[plugin](plugin_config) diff --git a/backend/protocol_rpc/configuration.py b/backend/protocol_rpc/configuration.py index eed7e9cc3..e3d36c2a5 100644 --- a/backend/protocol_rpc/configuration.py +++ b/backend/protocol_rpc/configuration.py @@ -1,20 +1,8 @@ import os import json -import requests - class GlobalConfiguration: @staticmethod - def get_ollama_url(endpoint: str) -> str: - return f"{os.environ['OLAMAPROTOCOL']}://{os.environ['OLAMAHOST']}:{os.environ['OLAMAPORT']}/api/{endpoint}" - - def get_disabled_info_logs_endpoints(self) -> list: + def get_disabled_info_logs_endpoints() -> list: return json.loads(os.environ.get("DISABLE_INFO_LOGS_ENDPOINTS", "[]")) - - def get_available_ollama_models(self) -> list: - ollama_models_result = requests.get(self.get_ollama_url("tags")).json() - installed_ollama_models = [] - for ollama_model in ollama_models_result["models"]: - installed_ollama_models.append(ollama_model["name"].split(":")[0]) - return installed_ollama_models diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index b528a67f6..ad3212b39 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -239,7 +239,6 @@ def create_random_validators( limit_models = limit_models or [] details = random_validator_config( - config.get_available_ollama_models, llm_provider_registry.get_all, limit_providers=set(limit_providers), limit_models=set(limit_models), diff --git a/backend/protocol_rpc/server.py b/backend/protocol_rpc/server.py index f76c80737..ebd25e32c 100644 --- a/backend/protocol_rpc/server.py +++ b/backend/protocol_rpc/server.py @@ -93,7 +93,7 @@ def create_app(): ) -# This method ensures that the transaction is committed or rolled back depending on the success of the request. +# This ensures that the transaction is committed or rolled back depending on the success of the request. # Opinions on whether this is a good practice are divided https://github.com/pallets-eco/flask-sqlalchemy/issues/216 @app.teardown_appcontext def shutdown_session(exception=None): From 723d12679fbf5052099f10cbee3cf9432cf9722f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Wed, 11 Sep 2024 16:55:35 -0300 Subject: [PATCH 58/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .github/workflows/backend_integration_tests_pr.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 9cb77601d..0ea395ed5 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -62,8 +62,7 @@ jobs: - name: Build Docker images run: | - docker compose build build-backend-base - docker compose build + docker compose build --with-dependencies - name: Run Docker Compose run: docker compose up -d From 801f45ab9bbe4db521367e1c6fd8ca0e4c5cf69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 08:16:59 -0300 Subject: [PATCH 59/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/create_nodes.py | 4 +- backend/protocol_rpc/endpoints.py | 3 +- tests/unit/test_create_nodes.py | 651 +++++++++++----------- 3 files changed, 341 insertions(+), 317 deletions(-) diff --git a/backend/node/create_nodes/create_nodes.py b/backend/node/create_nodes/create_nodes.py index b75d1d201..41720abc5 100644 --- a/backend/node/create_nodes/create_nodes.py +++ b/backend/node/create_nodes/create_nodes.py @@ -5,7 +5,7 @@ from dotenv import load_dotenv from backend.domain.types import LLMProvider -from backend.node.genvm.llms import get_llm_plugin +from backend.node.genvm.llms import Plugin load_dotenv() rng = default_rng(secrets.randbits(128)) @@ -13,6 +13,7 @@ def random_validator_config( get_stored_providers: Callable[[], List[LLMProvider]], + get_llm_plugin: Callable[[str, dict], Plugin], limit_providers: set[str] = None, limit_models: set[str] = None, amount: int = 1, @@ -37,7 +38,6 @@ def random_validator_config( ) def filter_by_available(provider: LLMProvider) -> bool: - # TODO: we should probably inject the `get_llm_plugin` function to be able to mock it in tests plugin = get_llm_plugin(provider.plugin, provider.plugin_config) if not plugin.is_available(): return False diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index ad3212b39..4097c1723 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -11,9 +11,9 @@ from backend.domain.types import LLMProvider, Validator from backend.node.create_nodes.providers import ( get_default_provider_for, - get_default_providers, validate_provider, ) +from backend.node.genvm.llms import get_llm_plugin from backend.protocol_rpc.configuration import GlobalConfiguration from backend.protocol_rpc.message_handler.base import MessageHandler from backend.database_handler.accounts_manager import AccountsManager @@ -240,6 +240,7 @@ def create_random_validators( details = random_validator_config( llm_provider_registry.get_all, + get_llm_plugin, limit_providers=set(limit_providers), limit_models=set(limit_models), amount=count, diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index 7394e8c08..cb24527ee 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -1,362 +1,385 @@ +import asyncio +from typing import Callable, List, Optional import pytest from backend.domain.types import LLMProvider from backend.node.create_nodes.create_nodes import random_validator_config +from backend.node.genvm.llms import Plugin + + +def plugin_mock(available: bool, available_models: List[str]) -> Plugin: + class PluginMock(Plugin): + def __init__(self, plugin_config: dict): + pass + + async def call( + self, + node_config: dict, + prompt: str, + regex: Optional[str], + return_streaming_channel: Optional[asyncio.Queue], + ) -> str: + pass + + def is_available(self): + return available + + def is_model_available(self, model: str) -> bool: + return model in available_models + + return PluginMock({}) @pytest.mark.parametrize( - "available_ollama_models,stored_providers,limit_providers,limit_models,amount,environ,expected", + "stored_providers,plugins,limit_providers,limit_models,amount,expected", [ pytest.param( - ["llama3"], [ LLMProvider( provider="ollama", model="llama3", config={}, - plugin="", + plugin="ollama", plugin_config={}, ) ], + {"ollama": plugin_mock(True, ["llama3"])}, None, None, 10, - {}, [ LLMProvider( provider="ollama", model="llama3", config={}, - plugin="", + plugin="ollama", plugin_config={}, ) ], id="only ollama", ), - pytest.param( - ["llama3", "llama3.1"], - [ - LLMProvider( - provider="ollama", - model="llama3.1", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4o", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="", - config={}, - plugin="", - plugin_config={}, - ), - ], - None, - None, - 10, - {"OPENAI_API_KEY": ""}, - [ - LLMProvider( - provider="ollama", - model="llama3.1", - config={}, - plugin="", - plugin_config={}, - ) - ], - id="only ollama available", - ), - pytest.param( - ["llama3", "llama3.1"], - [ - LLMProvider( - provider="ollama", - model="llama3", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4o", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="a", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="b", - config={}, - plugin="", - plugin_config={}, - ), - ], - ["openai"], - None, - 10, - {"OPENAIKEY": "filled"}, - [ - LLMProvider( - provider="openai", - model="gpt-4", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4o", - config={}, - plugin="", - plugin_config={}, - ), - ], - id="only openai", - ), - pytest.param( - ["llama3", "llama3.1"], - [ - LLMProvider( - provider="openai", - model="gpt-4", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4o", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="a", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="b", - config={}, - plugin="", - plugin_config={}, - ), - ], - ["heuristai"], - ["a"], - 10, - {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, - [ - LLMProvider( - provider="heuristai", - model="a", - config={}, - plugin="", - plugin_config={}, - ), - ], - id="only heuristai", - ), - pytest.param( - ["llama3", "llama3.1"], - [ - LLMProvider( - provider="ollama", - model="llama3.1", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4o", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="a", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="b", - config={}, - plugin="", - plugin_config={}, - ), - ], - None, - None, - 10, - {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, - [ - LLMProvider( - provider="ollama", - model="llama3.1", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="openai", - model="gpt-4o", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="a", - config={}, - plugin="", - plugin_config={}, - ), - LLMProvider( - provider="heuristai", - model="b", - config={}, - plugin="", - plugin_config={}, - ), - ], - id="all available", - ), + # pytest.param( + # ["llama3", "llama3.1"], + # [ + # LLMProvider( + # provider="ollama", + # model="llama3.1", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4o", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # None, + # None, + # 10, + # {"OPENAI_API_KEY": ""}, + # [ + # LLMProvider( + # provider="ollama", + # model="llama3.1", + # config={}, + # plugin="", + # plugin_config={}, + # ) + # ], + # id="only ollama available", + # ), + # pytest.param( + # ["llama3", "llama3.1"], + # [ + # LLMProvider( + # provider="ollama", + # model="llama3", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4o", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="a", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="b", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # ["openai"], + # None, + # 10, + # {"OPENAIKEY": "filled"}, + # [ + # LLMProvider( + # provider="openai", + # model="gpt-4", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4o", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # id="only openai", + # ), + # pytest.param( + # ["llama3", "llama3.1"], + # [ + # LLMProvider( + # provider="openai", + # model="gpt-4", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4o", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="a", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="b", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # ["heuristai"], + # ["a"], + # 10, + # {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, + # [ + # LLMProvider( + # provider="heuristai", + # model="a", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # id="only heuristai", + # ), + # pytest.param( + # ["llama3", "llama3.1"], + # [ + # LLMProvider( + # provider="ollama", + # model="llama3.1", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4o", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="a", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="b", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # None, + # None, + # 10, + # {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, + # [ + # LLMProvider( + # provider="ollama", + # model="llama3.1", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="openai", + # model="gpt-4o", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="a", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # LLMProvider( + # provider="heuristai", + # model="b", + # config={}, + # plugin="", + # plugin_config={}, + # ), + # ], + # id="all available", + # ), ], ) def test_random_validator_config( - available_ollama_models, stored_providers, + plugins, limit_providers, limit_models, amount, - environ, expected, ): result = random_validator_config( - lambda: available_ollama_models, lambda: stored_providers, + lambda plugin, config: plugins[plugin], limit_providers, limit_models, amount, - environ, ) assert set(result).issubset(set(expected)) -@pytest.mark.parametrize( - "available_ollama_models,stored_providers,limit_providers,limit_models,amount,environ,exception", - [ - pytest.param( - [], - [ - LLMProvider( - provider="ollama", - model="llama3", - config={}, - plugin="", - plugin_config={}, - ) - ], - ["heuristai", "openai"], - None, - 10, - {}, - ValueError, - id="no match", - ), - pytest.param( - [], - [ - LLMProvider( - provider="ollama", - model="llama3", - config={}, - plugin="", - plugin_config={}, - ) - ], - ["ollama"], - None, - 10, - {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, - Exception, - id="no intersection", - ), - ], -) -def test_random_validator_config_fail( - available_ollama_models, - stored_providers, - limit_providers, - limit_models, - amount, - environ, - exception, -): - with pytest.raises(exception): - random_validator_config( - result=random_validator_config( - lambda: available_ollama_models, - lambda: stored_providers, - limit_providers, - limit_models, - amount, - environ, - ) - ) +# @pytest.mark.parametrize( +# "available_ollama_models,stored_providers,limit_providers,limit_models,amount,environ,exception", +# [ +# pytest.param( +# [], +# [ +# LLMProvider( +# provider="ollama", +# model="llama3", +# config={}, +# plugin="", +# plugin_config={}, +# ) +# ], +# ["heuristai", "openai"], +# None, +# 10, +# {}, +# ValueError, +# id="no match", +# ), +# pytest.param( +# [], +# [ +# LLMProvider( +# provider="ollama", +# model="llama3", +# config={}, +# plugin="", +# plugin_config={}, +# ) +# ], +# ["ollama"], +# None, +# 10, +# {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, +# Exception, +# id="no intersection", +# ), +# ], +# ) +# def test_random_validator_config_fail( +# available_ollama_models, +# stored_providers, +# limit_providers, +# limit_models, +# amount, +# environ, +# exception, +# ): +# with pytest.raises(exception): +# random_validator_config( +# result=random_validator_config( +# lambda: available_ollama_models, +# lambda: stored_providers, +# limit_providers, +# limit_models, +# amount, +# environ, +# ) +# ) From a2c74d2bdd4137aafa78d93d635490f266ab0576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 08:24:24 -0300 Subject: [PATCH 60/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- tests/unit/test_create_nodes.py | 626 ++++++++++++++++---------------- 1 file changed, 316 insertions(+), 310 deletions(-) diff --git a/tests/unit/test_create_nodes.py b/tests/unit/test_create_nodes.py index cb24527ee..a182466cb 100644 --- a/tests/unit/test_create_nodes.py +++ b/tests/unit/test_create_nodes.py @@ -30,7 +30,7 @@ def is_model_available(self, model: str) -> bool: @pytest.mark.parametrize( - "stored_providers,plugins,limit_providers,limit_models,amount,expected", + "stored_providers,plugins,limit_providers,limit_models,expected", [ pytest.param( [ @@ -45,7 +45,6 @@ def is_model_available(self, model: str) -> bool: {"ollama": plugin_mock(True, ["llama3"])}, None, None, - 10, [ LLMProvider( provider="ollama", @@ -57,250 +56,258 @@ def is_model_available(self, model: str) -> bool: ], id="only ollama", ), - # pytest.param( - # ["llama3", "llama3.1"], - # [ - # LLMProvider( - # provider="ollama", - # model="llama3.1", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4o", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # None, - # None, - # 10, - # {"OPENAI_API_KEY": ""}, - # [ - # LLMProvider( - # provider="ollama", - # model="llama3.1", - # config={}, - # plugin="", - # plugin_config={}, - # ) - # ], - # id="only ollama available", - # ), - # pytest.param( - # ["llama3", "llama3.1"], - # [ - # LLMProvider( - # provider="ollama", - # model="llama3", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4o", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="a", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="b", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # ["openai"], - # None, - # 10, - # {"OPENAIKEY": "filled"}, - # [ - # LLMProvider( - # provider="openai", - # model="gpt-4", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4o", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # id="only openai", - # ), - # pytest.param( - # ["llama3", "llama3.1"], - # [ - # LLMProvider( - # provider="openai", - # model="gpt-4", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4o", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="a", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="b", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # ["heuristai"], - # ["a"], - # 10, - # {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, - # [ - # LLMProvider( - # provider="heuristai", - # model="a", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # id="only heuristai", - # ), - # pytest.param( - # ["llama3", "llama3.1"], - # [ - # LLMProvider( - # provider="ollama", - # model="llama3.1", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4o", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="a", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="b", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # None, - # None, - # 10, - # {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, - # [ - # LLMProvider( - # provider="ollama", - # model="llama3.1", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="openai", - # model="gpt-4o", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="a", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # LLMProvider( - # provider="heuristai", - # model="b", - # config={}, - # plugin="", - # plugin_config={}, - # ), - # ], - # id="all available", - # ), + pytest.param( + [ + LLMProvider( + provider="ollama", + model="llama3.1", + config={}, + plugin="ollama", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4o", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="heuristai", + config={}, + plugin="heuristai", + plugin_config={}, + ), + ], + { + "ollama": plugin_mock(True, ["llama3", "llama3.1"]), + "openai": plugin_mock(False, []), + "heuristai": plugin_mock(True, ["other"]), + }, + None, + None, + [ + LLMProvider( + provider="ollama", + model="llama3.1", + config={}, + plugin="ollama", + plugin_config={}, + ) + ], + id="only ollama available", + ), + pytest.param( + [ + LLMProvider( + provider="ollama", + model="llama3", + config={}, + plugin="ollama", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4o", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="a", + config={}, + plugin="heuristai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="b", + config={}, + plugin="heuristai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="c", + config={}, + plugin="heuristai", + plugin_config={}, + ), + ], + { + "ollama": plugin_mock(True, ["llama3", "llama3.1"]), + "openai": plugin_mock(True, ["gpt-4", "gpt-4o"]), + "heuristai": plugin_mock(True, ["other"]), + }, + ["openai"], + None, + [ + LLMProvider( + provider="openai", + model="gpt-4", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4o", + config={}, + plugin="openai", + plugin_config={}, + ), + ], + id="only openai", + ), + pytest.param( + [ + LLMProvider( + provider="openai", + model="gpt-4", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4o", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="a", + config={}, + plugin="heuristai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="b", + config={}, + plugin="heuristai", + plugin_config={}, + ), + ], + { + "ollama": plugin_mock(False, ["llama3", "llama3.1"]), + "openai": plugin_mock(False, ["gpt-4", "gpt-4o"]), + "heuristai": plugin_mock(True, ["a", "b"]), + }, + ["heuristai"], + ["a"], + [ + LLMProvider( + provider="heuristai", + model="a", + config={}, + plugin="heuristai", + plugin_config={}, + ), + ], + id="only heuristai", + ), + pytest.param( + [ + LLMProvider( + provider="ollama", + model="llama3.1", + config={}, + plugin="ollama", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4o", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="a", + config={}, + plugin="heuristai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="b", + config={}, + plugin="heuristai", + plugin_config={}, + ), + ], + { + "ollama": plugin_mock(True, ["llama3", "llama3.1"]), + "openai": plugin_mock(True, ["gpt-4", "gpt-4o"]), + "heuristai": plugin_mock(True, ["a", "b"]), + }, + None, + None, + [ + LLMProvider( + provider="ollama", + model="llama3.1", + config={}, + plugin="ollama", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="openai", + model="gpt-4o", + config={}, + plugin="openai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="a", + config={}, + plugin="heuristai", + plugin_config={}, + ), + LLMProvider( + provider="heuristai", + model="b", + config={}, + plugin="heuristai", + plugin_config={}, + ), + ], + id="all available", + ), ], ) def test_random_validator_config( @@ -308,7 +315,6 @@ def test_random_validator_config( plugins, limit_providers, limit_models, - amount, expected, ): result = random_validator_config( @@ -316,70 +322,70 @@ def test_random_validator_config( lambda plugin, config: plugins[plugin], limit_providers, limit_models, - amount, + 10, ) - assert set(result).issubset(set(expected)) + result_set = set(result) + expected_set = set(expected) + + assert result_set.issubset(expected_set) -# @pytest.mark.parametrize( -# "available_ollama_models,stored_providers,limit_providers,limit_models,amount,environ,exception", -# [ -# pytest.param( -# [], -# [ -# LLMProvider( -# provider="ollama", -# model="llama3", -# config={}, -# plugin="", -# plugin_config={}, -# ) -# ], -# ["heuristai", "openai"], -# None, -# 10, -# {}, -# ValueError, -# id="no match", -# ), -# pytest.param( -# [], -# [ -# LLMProvider( -# provider="ollama", -# model="llama3", -# config={}, -# plugin="", -# plugin_config={}, -# ) -# ], -# ["ollama"], -# None, -# 10, -# {"OPENAI_API_KEY": "filled", "HEURISTAI_API_KEY": "filled"}, -# Exception, -# id="no intersection", -# ), -# ], -# ) -# def test_random_validator_config_fail( -# available_ollama_models, -# stored_providers, -# limit_providers, -# limit_models, -# amount, -# environ, -# exception, -# ): -# with pytest.raises(exception): -# random_validator_config( -# result=random_validator_config( -# lambda: available_ollama_models, -# lambda: stored_providers, -# limit_providers, -# limit_models, -# amount, -# environ, -# ) -# ) +@pytest.mark.parametrize( + "stored_providers,plugins,limit_providers,limit_models,exception", + [ + pytest.param( + [ + LLMProvider( + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, + ) + ], + {}, + ["heuristai", "openai"], + None, + ValueError, + id="no match", + ), + pytest.param( + [ + LLMProvider( + provider="ollama", + model="llama3", + config={}, + plugin="", + plugin_config={}, + ) + ], + { + "ollama": plugin_mock(False, ["llama3", "llama3.1"]), + "openai": plugin_mock(True, ["gpt-4", "gpt-4o"]), + "heuristai": plugin_mock(True, ["a", "b"]), + }, + ["ollama"], + None, + Exception, + id="no intersection", + ), + ], +) +def test_random_validator_config_fail( + stored_providers, + plugins, + limit_providers, + limit_models, + exception, +): + with pytest.raises(exception): + random_validator_config( + result=random_validator_config( + lambda: stored_providers, + lambda plugin, config: plugins[plugin], + limit_providers, + limit_models, + 10, + ) + ) From f778e7c7321a67e63ad221336893d8ae19623f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 08:46:01 -0300 Subject: [PATCH 61/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .github/workflows/backend_integration_tests_pr.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 0ea395ed5..2d0c30c76 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -62,6 +62,7 @@ jobs: - name: Build Docker images run: | + docker compose build build-backend-base docker compose build --with-dependencies - name: Run Docker Compose From 7ba405893273bf806538236613b0bd7c78705ccf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 08:56:22 -0300 Subject: [PATCH 62/75] test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .github/workflows/backend_integration_tests_pr.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 2d0c30c76..d7f6a1805 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -60,10 +60,11 @@ jobs: restore-keys: | ${{ runner.os }}-buildx- + - name: Build Backend base image + run: docker compose build build-backend-base + - name: Build Docker images - run: | - docker compose build build-backend-base - docker compose build --with-dependencies + run: docker compose build --with-dependencies - name: Run Docker Compose run: docker compose up -d From 329ec437e83a241573f398616c370038b681567f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 09:32:51 -0300 Subject: [PATCH 63/75] test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .github/workflows/backend_integration_tests_pr.yml | 6 ------ docker-compose.yml | 1 - 2 files changed, 7 deletions(-) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index d7f6a1805..9a5de547e 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -60,12 +60,6 @@ jobs: restore-keys: | ${{ runner.os }}-buildx- - - name: Build Backend base image - run: docker compose build build-backend-base - - - name: Build Docker images - run: docker compose build --with-dependencies - - name: Run Docker Compose run: docker compose up -d diff --git a/docker-compose.yml b/docker-compose.yml index 9ce610f3b..d358aeeec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -57,7 +57,6 @@ services: volumes: - ./.ollama:/root/.ollama container_name: ollama - pull_policy: always tty: true restart: always From 59f8f88d138ae52e439ed69459fd8533cc270bb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 10:24:06 -0300 Subject: [PATCH 64/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../backend_integration_tests_pr.yml | 3 +++ docker-compose.yml | 14 -------------- docker/Dockerfile.backend | 4 ++-- docker/Dockerfile.database-migration | 19 +++++++++++++++++-- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 9a5de547e..7dcbe6a30 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -60,6 +60,9 @@ jobs: restore-keys: | ${{ runner.os }}-buildx- + - name: Build Docker images + run: docker compose build + - name: Run Docker Compose run: docker compose up -d diff --git a/docker-compose.yml b/docker-compose.yml index d358aeeec..cada9ff55 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -78,26 +78,12 @@ services: #volumes: # - "./data/postgres:/var/lib/postgresql/data" - # Used for caching - build-backend-base: - build: - context: . - dockerfile: docker/Dockerfile.backend - target: base - image: backend-base:latest - pull_policy: build - database-migration: build: context: . dockerfile: docker/Dockerfile.database-migration - args: - BASE_IMAGE: backend-base:latest - pull_policy: never environment: - DB_URL=postgresql://${DBUSER}:${DBUSER}@postgres/${DBNAME} depends_on: - build-backend-base: - condition: service_completed_successfully postgres: condition: service_healthy diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index 91b446a82..b865c0641 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -15,8 +15,8 @@ RUN groupadd -r backend-group \ && chown -R backend-user:backend-group $path ENV PYTHONPATH "${PYTHONPATH}:/${path}" -ENV FLASK_APP=backend/protocol_rpc/server.py -ENV TRANSFORMERS_CACHE=/home/backend-user/.cache/huggingface +ENV FLASK_APP backend/protocol_rpc/server.py +ENV TRANSFORMERS_CACHE /home/backend-user/.cache/huggingface COPY ../.env . COPY backend $path/backend diff --git a/docker/Dockerfile.database-migration b/docker/Dockerfile.database-migration index 8910019de..6db06a92b 100644 --- a/docker/Dockerfile.database-migration +++ b/docker/Dockerfile.database-migration @@ -1,5 +1,20 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} +# base image mostly copied from Dockerfile.backend to reuse cache +FROM python:3.12.5-slim AS base + +ARG path=/app +WORKDIR $path + +ADD backend/protocol_rpc/requirements.txt backend/protocol_rpc/requirements.txt +RUN pip install --upgrade pip \ + && pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu \ + && pip install --no-cache-dir -r backend/protocol_rpc/requirements.txt + +ENV TRANSFORMERS_CACHE /home/backend-user/.cache/huggingface + +COPY ../.env . +COPY backend $path/backend + +FROM base AS migration ENV PYTHONPATH "" WORKDIR /app/backend/database_handler From 0461bf1d48ce842833a37410fab415330db2bdd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 10:54:43 -0300 Subject: [PATCH 65/75] improve schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../node/create_nodes/providers_schema.json | 77 +++++++++++++------ tests/unit/test_providers.py | 72 ++++++++++++++++- 2 files changed, 124 insertions(+), 25 deletions(-) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 52df77790..26e406e09 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -39,7 +39,58 @@ "properties": { "plugin": { "const": "ollama" + } + } + } + }, + { + "if": { + "properties": { + "provider": { "const": "heuristai" } + } + }, + "then": { + "properties": { + "plugin": { + "const": "heuristai" + }, + "model": { + "enum": [ + "mistralai/mixtral-8x7b-instruct", + "meta-llama/llama-2-70b-chat", + "openhermes-2-yi-34b-gptq", + "dolphin-2.9-llama3-8b" + ] + } + } + } + }, + { + "if": { + "properties": { + "provider": { "const": "openai" } + } + }, + "then": { + "properties": { + "plugin": { + "const": "openai" }, + "model": { + "enum": ["gpt-4", "gpt-4o", "gpt-4o-mini"] + } + } + } + }, + + { + "if": { + "properties": { + "plugin": { "const": "ollama" } + } + }, + "then": { + "properties": { "plugin_config": { "type": "object", "additionalProperties": false, @@ -51,9 +102,7 @@ } } }, - "model": { - "enum": ["llama3", "mistral", "gemma"] - }, + "config": { "type": "object", "additionalProperties": false, @@ -179,14 +228,11 @@ { "if": { "properties": { - "provider": { "const": "heuristai" } + "plugin": { "const": "heuristai" } } }, "then": { "properties": { - "plugin": { - "const": "heuristai" - }, "plugin_config": { "type": "object", "additionalProperties": false, @@ -202,14 +248,6 @@ }, "required": ["api_key_env_var", "api_url"] }, - "model": { - "enum": [ - "mistralai/mixtral-8x7b-instruct", - "meta-llama/llama-2-70b-chat", - "openhermes-2-yi-34b-gptq", - "dolphin-2.9-llama3-8b" - ] - }, "config": { "type": "object", "additionalProperties": false, @@ -237,14 +275,11 @@ { "if": { "properties": { - "provider": { "const": "openai" } + "plugin": { "const": "openai" } } }, "then": { "properties": { - "plugin": { - "const": "openai" - }, "plugin_config": { "type": "object", "additionalProperties": false, @@ -256,9 +291,7 @@ }, "required": ["api_key_env_var"] }, - "model": { - "enum": ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4o-mini"] - }, + "config": { "const": "" } diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index a71859d0d..435416814 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -1,9 +1,75 @@ -from backend.node.create_nodes.providers import ( - get_default_providers, -) +import pytest +from backend.domain.types import LLMProvider +from backend.node.create_nodes.providers import get_default_providers, validate_provider def test_default_providers_valid(): providers = get_default_providers() assert len(providers) > 0 + + +@pytest.mark.parametrize( + "llm_provider", + [ + pytest.param( + LLMProvider( + plugin="openai", + provider="custom provider", + model="custom model", + config="", + plugin_config={ + "api_key_env_var": "some api key", + }, + ), + id="custom openai", + ), + pytest.param( + LLMProvider( + plugin="heuristai", + provider="custom provider", + model="custom model", + config={ + "max_tokens": 100, + "temperature": 0.5, + }, + plugin_config={ + "api_key_env_var": "some api key", + "api_url": "http://localhost:8000", + }, + ), + id="custom heuristai", + ), + pytest.param( + LLMProvider( + plugin="ollama", + provider="custom provider", + model="custom model", + config={ + "mirostat": 0, + "mirostat_eta": 0.1, + "microstat_tau": 5, + "num_ctx": 2048, + "num_qga": 8, + "num_gpu": 0, + "num_thread": 2, + "repeat_last_n": 64, + "repeat_penalty": 1.1, + "temprature": 0.8, + "seed": 0, + "stop": "", + "tfs_z": 1.0, + "num_predict": 128, + "top_k": 40, + "top_p": 0.9, + }, + plugin_config={ + "api_url": "http://localhost:8000", + }, + ), + id="custom ollama", + ), + ], +) +def test_validate_provider(llm_provider): + validate_provider(llm_provider) From 95f2ec997d582c9fd9ce0568bbc751915d6395af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 13:51:21 -0300 Subject: [PATCH 66/75] unify HeuristAI and OpenAI plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../heuristai_dolphin-2.9-llama3-8b.json | 2 +- .../heuristai_meta-llamallama-2-70b-chat.json | 2 +- ...ristai_mistralaimixtral-8x7b-instruct.json | 2 +- .../heuristai_openhermes-2-yi-34b-gptq.json | 2 +- .../default_providers/openai_gpt-4.json | 5 +- .../default_providers/openai_gpt-4o-mini.json | 5 +- .../default_providers/openai_gpt-4o.json | 5 +- .../node/create_nodes/providers_schema.json | 46 ++------ backend/node/genvm/llms.py | 77 ++------------ backend/protocol_rpc/requirements.txt | 2 +- .../test_llm_providers_registry.py | 4 +- tests/plugins/__init__.py | 0 tests/plugins/test_llms.py | 100 ++++++++++++++++++ tests/unit/test_providers.py | 13 +-- 14 files changed, 137 insertions(+), 128 deletions(-) create mode 100644 tests/plugins/__init__.py create mode 100644 tests/plugins/test_llms.py diff --git a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json index b18bead58..5715d0b5e 100644 --- a/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json +++ b/backend/node/create_nodes/default_providers/heuristai_dolphin-2.9-llama3-8b.json @@ -1,6 +1,6 @@ { "provider": "heuristai", - "plugin": "heuristai", + "plugin": "openai", "model": "dolphin-2.9-llama3-8b", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json index 15a514b26..a75e99389 100644 --- a/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json +++ b/backend/node/create_nodes/default_providers/heuristai_meta-llamallama-2-70b-chat.json @@ -1,6 +1,6 @@ { "provider": "heuristai", - "plugin": "heuristai", + "plugin": "openai", "model": "meta-llama/llama-2-70b-chat", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json index edfa965ea..7b84a0532 100644 --- a/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json +++ b/backend/node/create_nodes/default_providers/heuristai_mistralaimixtral-8x7b-instruct.json @@ -1,6 +1,6 @@ { "provider": "heuristai", - "plugin": "heuristai", + "plugin": "openai", "model": "mistralai/mixtral-8x7b-instruct", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json index a4c78975f..bcc2b01f9 100644 --- a/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json +++ b/backend/node/create_nodes/default_providers/heuristai_openhermes-2-yi-34b-gptq.json @@ -1,6 +1,6 @@ { "provider": "heuristai", - "plugin": "heuristai", + "plugin": "openai", "model": "openhermes-2-yi-34b-gptq", "config": { "temperature": 0.75, diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4.json b/backend/node/create_nodes/default_providers/openai_gpt-4.json index 2872fc8be..033955b4c 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4.json @@ -2,8 +2,9 @@ "provider": "openai", "plugin": "openai", "model": "gpt-4", - "config": "", + "config": {}, "plugin_config": { - "api_key_env_var": "OPENAIKEY" + "api_key_env_var": "OPENAIKEY", + "api_url": null } } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json index e6a9d71be..c32a606ac 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o-mini.json @@ -2,8 +2,9 @@ "provider": "openai", "plugin": "openai", "model": "gpt-4o-mini", - "config": "", + "config": {}, "plugin_config": { - "api_key_env_var": "OPENAIKEY" + "api_key_env_var": "OPENAIKEY", + "api_url": null } } diff --git a/backend/node/create_nodes/default_providers/openai_gpt-4o.json b/backend/node/create_nodes/default_providers/openai_gpt-4o.json index 69b187198..f9c3eab95 100644 --- a/backend/node/create_nodes/default_providers/openai_gpt-4o.json +++ b/backend/node/create_nodes/default_providers/openai_gpt-4o.json @@ -2,8 +2,9 @@ "provider": "openai", "plugin": "openai", "model": "gpt-4o", - "config": "", + "config": {}, "plugin_config": { - "api_key_env_var": "OPENAIKEY" + "api_key_env_var": "OPENAIKEY", + "api_url": null } } diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 26e406e09..c70c6828c 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -11,14 +11,7 @@ "type": "string" }, "config": { - "oneOf": [ - { - "type": "object" - }, - { - "const": "" - } - ] + "type": "object" }, "plugin": { "$comment": "plugin to be loaded by the simulator to interact with the provider", @@ -52,7 +45,7 @@ "then": { "properties": { "plugin": { - "const": "heuristai" + "const": "openai" }, "model": { "enum": [ @@ -228,7 +221,7 @@ { "if": { "properties": { - "plugin": { "const": "heuristai" } + "plugin": { "const": "openai" } } }, "then": { @@ -242,8 +235,8 @@ "$comment": "Environment variable that contains the API key" }, "api_url": { - "type": "string", - "$comment": "URL of the API endpoint" + "type": ["string", "null"], + "$comment": "URL of the API endpoint. `null` is used to represent the official OpenAI API" } }, "required": ["api_key_env_var", "api_url"] @@ -266,34 +259,7 @@ "multipleOf": 10, "default": 500 } - }, - "required": ["temperature", "max_tokens"] - } - } - } - }, - { - "if": { - "properties": { - "plugin": { "const": "openai" } - } - }, - "then": { - "properties": { - "plugin_config": { - "type": "object", - "additionalProperties": false, - "properties": { - "api_key_env_var": { - "type": "string", - "$comment": "Environment variable that contains the API key" - } - }, - "required": ["api_key_env_var"] - }, - - "config": { - "const": "" + } } } } diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index 7ca265988..0268f46ef 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -16,6 +16,7 @@ from typing import Optional from openai import OpenAI, Stream from openai.types.chat import ChatCompletionChunk +from urllib.parse import urljoin from dotenv import load_dotenv import requests @@ -37,7 +38,7 @@ async def process_streaming_buffer(buffer: str, chunk: str, regex: str) -> str: async def stream_http_response(url, data): async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=False) + connector=aiohttp.TCPConnector(ssl=False) ) as session: async with session.post(url, json=data, ssl=False) as response: async for chunk in response.content.iter_any(): @@ -50,7 +51,7 @@ async def call_ollama( regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: - url = get_ollama_url("generate") + url = urljoin(node_config[plugin_config_key]["api_url"], "generate") data = {"model": node_config["model"], "prompt": prompt} @@ -82,36 +83,14 @@ async def call_openai( return_streaming_channel: Optional[asyncio.Queue], ) -> str: api_key_env_var = node_config[plugin_config_key]["api_key_env_var"] - client = get_openai_client(os.environ.get(api_key_env_var)) + url = node_config[plugin_config_key]["api_url"] + client = get_openai_client(os.environ.get(api_key_env_var), url) # TODO: OpenAI exceptions need to be caught here stream = get_openai_stream(client, prompt, node_config) return await get_openai_output(stream, regex, return_streaming_channel) -async def call_heuristai( - node_config: dict, - prompt: str, - regex: Optional[str], - return_streaming_channel: Optional[asyncio.Queue], -) -> str: - api_key_env_var = node_config[plugin_config_key]["api_key_env_var"] - url = node_config[plugin_config_key]["api_url"] - client = get_openai_client(os.environ.get(api_key_env_var), os.environ.get(url)) - stream = get_openai_stream(client, prompt, node_config) - # TODO: Get the line below working - # return await get_openai_output(stream, regex, return_streaming_channel) - output = "" - for chunk in stream: - # raise Exception(chunk.json(), dir(chunk), chunk.choices[0].delta.content) - try: - output += chunk.choices[0].delta.content - except Exception: - raise Exception(chunk.json(), dir(chunk)) - # return stream.choices[0].message.content - return output - - def get_openai_client(api_key: str, url: str = None) -> OpenAI: openai_client = None if url: @@ -122,7 +101,7 @@ def get_openai_client(api_key: str, url: str = None) -> OpenAI: def get_openai_stream(client: OpenAI, prompt, node_config): - config = node_config["config"] + config: dict = node_config["config"] if "temperature" in config and "max_tokens" in config: return client.chat.completions.create( model=node_config["model"], @@ -159,11 +138,9 @@ async def get_openai_output( if "done" in chunk_str: return buffer else: - return buffer + break - -def get_ollama_url(endpoint: str) -> str: - return f"{os.environ['OLAMAPROTOCOL']}://{os.environ['OLAMAHOST']}:{os.environ['OLAMAPORT']}/api/{endpoint}" + return buffer class Plugin(ABC): @@ -249,43 +226,6 @@ def is_model_available(self, model: str) -> bool: return True -class HeuristAIPlugin(Plugin): - def __init__(self, plugin_config: dict): - self.api_key_env_var = plugin_config["api_key_env_var"] - self.url = plugin_config["api_url"] - - async def call( - self, - node_config: dict, - prompt: str, - regex: Optional[str], - return_streaming_channel: Optional[asyncio.Queue], - ) -> str: - return await call_heuristai( - node_config, prompt, regex, return_streaming_channel - ) - - def is_available(self) -> bool: - env_var = os.environ.get(self.api_key_env_var) - - return ( - env_var != None - and env_var != "" - and env_var != "" - ) - - def is_model_available(self, model: str) -> bool: - """ - Model checks are done by the shema providers_schema.json - """ - # heuristic_models_result = requests.get(os.environ['HEURISTAIMODELSURL']).json() - # heuristic_models = [] - # for entry in heuristic_models_result: - # heuristic_models.append(entry['name']) - - return True - - def get_llm_plugin(plugin: str, plugin_config: dict) -> Plugin: """ Function to register new providers @@ -293,7 +233,6 @@ def get_llm_plugin(plugin: str, plugin_config: dict) -> Plugin: plugin_map = { "ollama": OllamaPlugin, "openai": OpenAIPlugin, - "heuristai": HeuristAIPlugin, } if plugin not in plugin_map: diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index 27f8db634..e897f9fd7 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -12,7 +12,7 @@ pytest==8.1.1 colorama==0.4.6 debugpy==1.8.1 aiohttp==3.9.3 -openai==1.16.1 +openai==1.44.1 SQLAlchemy[asyncio]==2.0.31 alembic==1.13.2 eth-account==0.13.1 diff --git a/tests/integration/test_llm_providers_registry.py b/tests/integration/test_llm_providers_registry.py index 7d6ff26db..e7f2e148b 100644 --- a/tests/integration/test_llm_providers_registry.py +++ b/tests/integration/test_llm_providers_registry.py @@ -6,7 +6,7 @@ def test_llm_providers(): provider = { "provider": "openai", "model": "gpt-4", - "config": "", + "config": {}, "plugin": "openai", "plugin_config": {"api_key_env_var": "OPENAIKEY"}, } @@ -19,7 +19,7 @@ def test_llm_providers(): updated_provider = { "provider": "openai", "model": "gpt-4o", - "config": "", + "config": {}, "plugin": "openai", "plugin_config": {"api_key_env_var": "OPENAIKEY"}, } diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/plugins/test_llms.py b/tests/plugins/test_llms.py new file mode 100644 index 000000000..0fe382531 --- /dev/null +++ b/tests/plugins/test_llms.py @@ -0,0 +1,100 @@ +""" +These are integration tests for the LLM Plugins +These tests are not intended to test the actual LLMs, but the plugins that interact with them +These tests will incur costs when LLM services are called +The purpose of these tests is to have a small feedback loop for developing the LLM plugins +""" + +import asyncio +from backend.node.genvm.llms import OllamaPlugin, OpenAIPlugin + + +def test_openai_plugin(): + plugin_config = {"api_key_env_var": "OPENAIKEY", "api_url": None} + node_config = { + "provider": "openai", + "model": "gpt-4o-mini", + "config": {"temperature": 0, "max_tokens": 10}, + "plugin_config": plugin_config, + } + + plugin = OpenAIPlugin(plugin_config) + result = asyncio.run( + plugin.call( + node_config=node_config, + prompt="Once upon a time", + regex=None, + return_streaming_channel=None, + ) + ) + + print(result) + assert result != None and result != "" and isinstance(result, str) + + +def test_heuristai_plugin(): + plugin_config = { + "api_key_env_var": "HEURISTAIAPIKEY", + "api_url": "https://llm-gateway.heurist.xyz", + } + node_config = { + "provider": "heuristai", + "model": "mistralai/mixtral-8x7b-instruct", + "config": {"temperature": 0, "max_tokens": 10}, + "plugin_config": plugin_config, + } + + plugin = OpenAIPlugin(plugin_config) + result = asyncio.run( + plugin.call( + node_config=node_config, + prompt="Once upon a time", + regex=None, + return_streaming_channel=None, + ) + ) + + print(result) + assert result != None and result != "" and isinstance(result, str) + + +def test_ollama_plugin(): + plugin_config = { + "api_url": "http://localhost:11434/api/", + } + node_config = { + "provider": "ollama", + "model": "llama3", + "config": { + "mirostat": 0, + "mirostat_eta": 0.1, + "microstat_tau": 5, + "num_ctx": 2048, + "num_qga": 8, + "num_gpu": 0, + "num_thread": 2, + "repeat_last_n": 64, + "repeat_penalty": 1.1, + "temprature": 0.8, + "seed": 0, + "stop": "", + "tfs_z": 1.0, + "num_predict": 128, + "top_k": 40, + "top_p": 0.9, + }, + "plugin_config": plugin_config, + } + + plugin = OllamaPlugin(plugin_config) + result = asyncio.run( + plugin.call( + node_config=node_config, + prompt="Once upon a time", + regex=None, + return_streaming_channel=None, + ) + ) + + print(result) + assert result != None and result != "" and isinstance(result, str) diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py index 435416814..64ae6d96b 100644 --- a/tests/unit/test_providers.py +++ b/tests/unit/test_providers.py @@ -17,28 +17,29 @@ def test_default_providers_valid(): plugin="openai", provider="custom provider", model="custom model", - config="", + config={}, plugin_config={ "api_key_env_var": "some api key", + "api_url": None, }, ), id="custom openai", ), pytest.param( LLMProvider( - plugin="heuristai", - provider="custom provider", - model="custom model", + plugin="openai", + provider="heuristai", + model="mistralai/mixtral-8x7b-instruct", config={ "max_tokens": 100, "temperature": 0.5, }, plugin_config={ "api_key_env_var": "some api key", - "api_url": "http://localhost:8000", + "api_url": "https://llm-gateway.heurist.xyz", }, ), - id="custom heuristai", + id="heuristai", ), pytest.param( LLMProvider( From 59e2885790f00793399bd1cd0a62471abd1fce83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 15:20:45 -0300 Subject: [PATCH 67/75] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- tests/integration/test_llm_providers_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_llm_providers_registry.py b/tests/integration/test_llm_providers_registry.py index e7f2e148b..ae2fbcbf7 100644 --- a/tests/integration/test_llm_providers_registry.py +++ b/tests/integration/test_llm_providers_registry.py @@ -8,7 +8,7 @@ def test_llm_providers(): "model": "gpt-4", "config": {}, "plugin": "openai", - "plugin_config": {"api_key_env_var": "OPENAIKEY"}, + "plugin_config": {"api_key_env_var": "OPENAIKEY", "api_url": None}, } # Create a new provider response = post_request_localhost(payload("add_provider", provider)).json() @@ -21,7 +21,7 @@ def test_llm_providers(): "model": "gpt-4o", "config": {}, "plugin": "openai", - "plugin_config": {"api_key_env_var": "OPENAIKEY"}, + "plugin_config": {"api_key_env_var": "OPENAIKEY", "api_url": None}, } # Uodate it response = post_request_localhost( From 74d8acf09ba8e86be7885146e506c475790e3063 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Thu, 12 Sep 2024 16:09:26 -0300 Subject: [PATCH 68/75] start adding anthropic plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .env.example | 3 + .../backend_integration_tests_pr.yml | 2 +- .../anthropic_claude-3-5-sonnet-20240620.json | 12 +++ .../anthropic_claude-3-haiku-20240307.json | 12 +++ .../anthropic_claude-3-opus-20240229.json | 12 +++ .../anthropic_claude-3-sonnet-20240229.json | 12 +++ .../node/create_nodes/providers_schema.json | 99 ++++++++++++++++++- backend/node/genvm/llms.py | 60 +++++++++++ backend/protocol_rpc/requirements.txt | 1 + tests/plugins/test_llms.py | 25 ++++- 10 files changed, 234 insertions(+), 4 deletions(-) create mode 100644 backend/node/create_nodes/default_providers/anthropic_claude-3-5-sonnet-20240620.json create mode 100644 backend/node/create_nodes/default_providers/anthropic_claude-3-haiku-20240307.json create mode 100644 backend/node/create_nodes/default_providers/anthropic_claude-3-opus-20240229.json create mode 100644 backend/node/create_nodes/default_providers/anthropic_claude-3-sonnet-20240229.json diff --git a/.env.example b/.env.example index cbe4e5199..54dbe9397 100644 --- a/.env.example +++ b/.env.example @@ -56,6 +56,9 @@ HEURISTAIMODELSURL = 'https://raw.githubusercontent.com/heurist-network/heurist # If you want to use Heurist AI add your key here HEURISTAIAPIKEY = '' +# If you want to use Anthropic (Claude AI) add your key here and uncomment the line +# ANTHROPIC_API_KEY = '' + # Front end container details VITE_JSON_RPC_SERVER_URL = 'http://127.0.0.1:4000/api' VITE_WS_SERVER_URL = 'ws://127.0.0.1:4000' diff --git a/.github/workflows/backend_integration_tests_pr.yml b/.github/workflows/backend_integration_tests_pr.yml index 7dcbe6a30..d931fd26e 100644 --- a/.github/workflows/backend_integration_tests_pr.yml +++ b/.github/workflows/backend_integration_tests_pr.yml @@ -42,7 +42,7 @@ jobs: - name: Copy .env file run: cp .env.example .env - # TODO: we should also add a heuristai key to the e2e tests + # TODO: we should also add also heuristai and anthropic keys to the e2e tests and test all providers - name: Set OPENAIKEY in the .env file so it can be loaded from the environment env: diff --git a/backend/node/create_nodes/default_providers/anthropic_claude-3-5-sonnet-20240620.json b/backend/node/create_nodes/default_providers/anthropic_claude-3-5-sonnet-20240620.json new file mode 100644 index 000000000..c6d7598e8 --- /dev/null +++ b/backend/node/create_nodes/default_providers/anthropic_claude-3-5-sonnet-20240620.json @@ -0,0 +1,12 @@ +{ + "provider": "anthropic", + "plugin": "anthropic", + "model": "claude-3-5-sonnet-20240620", + "config": { + "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "ANTHROPIC_API_KEY", + "api_url": null + } +} diff --git a/backend/node/create_nodes/default_providers/anthropic_claude-3-haiku-20240307.json b/backend/node/create_nodes/default_providers/anthropic_claude-3-haiku-20240307.json new file mode 100644 index 000000000..86f6ddccb --- /dev/null +++ b/backend/node/create_nodes/default_providers/anthropic_claude-3-haiku-20240307.json @@ -0,0 +1,12 @@ +{ + "provider": "anthropic", + "plugin": "anthropic", + "model": "claude-3-haiku-20240307", + "config": { + "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "ANTHROPIC_API_KEY", + "api_url": null + } +} diff --git a/backend/node/create_nodes/default_providers/anthropic_claude-3-opus-20240229.json b/backend/node/create_nodes/default_providers/anthropic_claude-3-opus-20240229.json new file mode 100644 index 000000000..14e0d6773 --- /dev/null +++ b/backend/node/create_nodes/default_providers/anthropic_claude-3-opus-20240229.json @@ -0,0 +1,12 @@ +{ + "provider": "anthropic", + "plugin": "anthropic", + "model": "claude-3-opus-20240229", + "config": { + "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "ANTHROPIC_API_KEY", + "api_url": null + } +} diff --git a/backend/node/create_nodes/default_providers/anthropic_claude-3-sonnet-20240229.json b/backend/node/create_nodes/default_providers/anthropic_claude-3-sonnet-20240229.json new file mode 100644 index 000000000..b28407ea5 --- /dev/null +++ b/backend/node/create_nodes/default_providers/anthropic_claude-3-sonnet-20240229.json @@ -0,0 +1,12 @@ +{ + "provider": "anthropic", + "plugin": "anthropic", + "model": "claude-3-sonnet-20240229", + "config": { + "max_tokens": 500 + }, + "plugin_config": { + "api_key_env_var": "ANTHROPIC_API_KEY", + "api_url": null + } +} diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index c70c6828c..bef05c89b 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -5,7 +5,7 @@ "properties": { "provider": { "type": "string", - "examples": ["ollama", "heuristai", "openai"] + "examples": ["ollama", "heuristai", "openai", "anthropic"] }, "model": { "type": "string" @@ -15,7 +15,7 @@ }, "plugin": { "$comment": "plugin to be loaded by the simulator to interact with the provider", - "enum": ["heuristai", "openai", "ollama"] + "enum": ["heuristai", "openai", "ollama", "anthropic"] }, "plugin_config": { "type": "object" @@ -75,6 +75,28 @@ } } }, + { + "if": { + "properties": { + "provider": { "const": "anthropic" } + } + }, + "then": { + "properties": { + "plugin": { + "const": "anthropic" + }, + "model": { + "enum": [ + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307" + ] + } + } + } + }, { "if": { @@ -263,6 +285,79 @@ } } } + }, + { + "if": { + "properties": { + "plugin": { "const": "anthropic" } + } + }, + "then": { + "properties": { + "plugin_config": { + "type": "object", + "additionalProperties": false, + "properties": { + "api_key_env_var": { + "type": "string", + "$comment": "Environment variable that contains the API key" + }, + "api_url": { + "type": ["string", "null"], + "$comment": "URL of the API endpoint. `null` is used to represent the official API" + } + }, + "required": ["api_key_env_var", "api_url"] + }, + "config": { + "type": "object", + "additionalProperties": false, + "properties": { + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 1, + "multipleOf": 0.05, + "default": 0.75 + }, + "max_tokens": { + "type": "integer", + "minimum": 100, + "maximum": 2000, + "multipleOf": 10, + "default": 500 + }, + "top_k": { + "type": "integer", + "minimum": 2, + "maximum": 100, + "default": 40 + }, + "top_p": { + "type": "number", + "minimum": 0.5, + "maximum": 0.99, + "multipleOf": 0.01, + "default": 0.9 + }, + "timeout": { + "type": "integer", + "minimum": 1, + "maximum": 60, + "default": 10 + }, + "stop_sequences": { + "type": "array", + "items": { + "type": "string" + }, + "default": [] + } + }, + "required": ["max_tokens"] + } + } + } } ], "required": ["provider", "model", "config", "plugin", "plugin_config"], diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index 0268f46ef..a9bbb3f29 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -16,6 +16,7 @@ from typing import Optional from openai import OpenAI, Stream from openai.types.chat import ChatCompletionChunk +from anthropic import AsyncAnthropic from urllib.parse import urljoin from dotenv import load_dotenv @@ -226,6 +227,65 @@ def is_model_available(self, model: str) -> bool: return True +class AnthropicPlugin(Plugin): + def __init__(self, plugin_config: dict): + self.api_key_env_var = plugin_config["api_key_env_var"] + self.url = plugin_config["api_url"] + + async def call( + self, + node_config: dict, + prompt: str, + regex: Optional[str], + return_streaming_channel: Optional[asyncio.Queue], + ) -> str: + if self.url: + client = AsyncAnthropic(api_key=self.get_api_key(), base_url=self.url) + else: + client = AsyncAnthropic(api_key=self.get_api_key()) + + buffer = "" + async with client.messages.create( + model=node_config["model"], + messages=[ + { + "role": "user", + "content": prompt, + } + ], + stream=True, + **node_config[ + "config" + ], # max_tokens, temperature, top_k, top_p, timeout, stop_sequences + ) as stream: + async for event in stream: + if event.type == "text": + buffer += event.text + if return_streaming_channel is not None: + await return_streaming_channel.put(event.text) + match = re.search(regex, buffer) + if match: + return match.group(0) + elif event.type == "content_block_stop": + break + + return await stream.get_final_message() + + def is_available(self) -> bool: + env_var = self.get_api_key() + + return env_var != None and env_var != "" + + def is_model_available(self, model: str) -> bool: + """ + Model checks are done by the shema providers_schema.json + """ + return True + + def get_api_key(self): + return os.environ.get(self.api_key_env_var) + + def get_llm_plugin(plugin: str, plugin_config: dict) -> Plugin: """ Function to register new providers diff --git a/backend/protocol_rpc/requirements.txt b/backend/protocol_rpc/requirements.txt index e897f9fd7..9c73e9fb7 100644 --- a/backend/protocol_rpc/requirements.txt +++ b/backend/protocol_rpc/requirements.txt @@ -13,6 +13,7 @@ colorama==0.4.6 debugpy==1.8.1 aiohttp==3.9.3 openai==1.44.1 +anthropic==0.34.2 SQLAlchemy[asyncio]==2.0.31 alembic==1.13.2 eth-account==0.13.1 diff --git a/tests/plugins/test_llms.py b/tests/plugins/test_llms.py index 0fe382531..602797099 100644 --- a/tests/plugins/test_llms.py +++ b/tests/plugins/test_llms.py @@ -6,7 +6,7 @@ """ import asyncio -from backend.node.genvm.llms import OllamaPlugin, OpenAIPlugin +from backend.node.genvm.llms import AnthropicPlugin, OllamaPlugin, OpenAIPlugin def test_openai_plugin(): @@ -98,3 +98,26 @@ def test_ollama_plugin(): print(result) assert result != None and result != "" and isinstance(result, str) + + +def test_anthropic_plugin(): + plugin_config = {"api_key_env_var": "ANTROPIC_API_KEY"} + node_config = { + "provider": "anthropic", + "model": "claude-3-5-sonnet-20240620", + "config": {"max_tokens": 10}, + "plugin_config": plugin_config, + } + + plugin = AnthropicPlugin(plugin_config) + result = asyncio.run( + plugin.call( + node_config=node_config, + prompt="Once upon a time", + regex=None, + return_streaming_channel=None, + ) + ) + + print(result) + assert result != None and result != "" and isinstance(result, str) From 749d8ea850a544441440eba7b112ca94129ee8f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 13 Sep 2024 08:24:53 -0300 Subject: [PATCH 69/75] fix: register anthropic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers_schema.json | 2 +- backend/node/genvm/llms.py | 1 + tests/integration/test_validators.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index bef05c89b..0830b7475 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -15,7 +15,7 @@ }, "plugin": { "$comment": "plugin to be loaded by the simulator to interact with the provider", - "enum": ["heuristai", "openai", "ollama", "anthropic"] + "enum": ["openai", "ollama", "anthropic"] }, "plugin_config": { "type": "object" diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index a9bbb3f29..f50b74731 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -293,6 +293,7 @@ def get_llm_plugin(plugin: str, plugin_config: dict) -> Plugin: plugin_map = { "ollama": OllamaPlugin, "openai": OpenAIPlugin, + "anthropic": AnthropicPlugin, } if plugin not in plugin_map: diff --git a/tests/integration/test_validators.py b/tests/integration/test_validators.py index d3a004c5c..32206a4f5 100644 --- a/tests/integration/test_validators.py +++ b/tests/integration/test_validators.py @@ -1,8 +1,10 @@ +import pytest from tests.common.request import payload, post_request_localhost from tests.common.response import has_success_status -def test_validators(): +@pytest.mark.parametrize("execution_number", range(5)) +def test_validators(execution_number): delete_validators_result = post_request_localhost( payload("delete_all_validators") ).json() From 048886ec09e9f66a562578d51b0d89317dc2bb2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 13 Sep 2024 10:01:29 -0300 Subject: [PATCH 70/75] fix validators in the frontend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/genvm/llms.py | 3 +- backend/protocol_rpc/endpoints.py | 40 ++++++++++++++---------- docker-compose.yml | 2 ++ frontend/src/services/IJsonRpcService.ts | 3 +- frontend/src/services/JsonRpcService.ts | 9 ++++-- frontend/src/stores/node.ts | 12 ++++++- frontend/src/types/index.ts | 2 ++ frontend/src/types/results.ts | 10 ++++++ frontend/test/unit/stores/node.test.ts | 4 +++ 9 files changed, 62 insertions(+), 23 deletions(-) diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index f50b74731..be0d82da2 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -27,7 +27,7 @@ plugin_config_key = "plugin_config" -async def process_streaming_buffer(buffer: str, chunk: str, regex: str) -> str: +async def process_streaming_buffer(buffer: str, chunk: str, regex: str) -> dict: updated_buffer = buffer + chunk if regex: match = re.search(regex, updated_buffer) @@ -65,7 +65,6 @@ async def call_ollama( if return_streaming_channel is not None: if not chunk.get("done"): await return_streaming_channel.put(chunk) - continue else: await return_streaming_channel.put({"done": True}) else: diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 4097c1723..51afd95a8 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -177,9 +177,9 @@ def create_validator( stake: int, provider: str, model: str, - config: dict | None, - plugin: str | None, - plugin_config: dict | None, + config: dict | None = None, + plugin: str | None = None, + plugin_config: dict | None = None, ) -> dict: # fallback for default provider # TODO: only accept all or none of the config fields @@ -266,27 +266,35 @@ def update_validator( stake: int, provider: str, model: str, - config: dict, - plugin: str, - plugin_config: dict, + config: dict | None = None, + plugin: str | None = None, + plugin_config: dict | None = None, ) -> dict: # Remove validation while adding migration to update the db address # if not accounts_manager.is_valid_address(validator_address): # raise InvalidAddressError(validator_address) - llm_provider = LLMProvider( - provider=provider, - model=model, - config=config, - plugin=plugin, - plugin_config=plugin_config, - id=None, - ) - validate_provider(llm_provider) + + # fallback for default provider + # TODO: only accept all or none of the config fields + llm_provider = None + if not (plugin and plugin_config): + llm_provider = get_default_provider_for(provider, model) + if config: + llm_provider.config = config + else: + llm_provider = LLMProvider( + provider=provider, + model=model, + config=config, + plugin=plugin, + plugin_config=plugin_config, + ) + validate_provider(llm_provider) + validator = Validator( address=validator_address, stake=stake, llmprovider=llm_provider, - id=None, ) return validators_registry.update_validator(validator) diff --git a/docker-compose.yml b/docker-compose.yml index cada9ff55..6b474f43f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,6 +7,8 @@ services: - "${FRONTEND_PORT}:${FRONTEND_PORT}" volumes: - ./examples:/app/src/assets/examples + - ./frontend:/app/frontend + entrypoint: ["npm", "run", "dev"] depends_on: - jsonrpc expose: diff --git a/frontend/src/services/IJsonRpcService.ts b/frontend/src/services/IJsonRpcService.ts index 0db1b0c39..d3a27324a 100644 --- a/frontend/src/services/IJsonRpcService.ts +++ b/frontend/src/services/IJsonRpcService.ts @@ -10,6 +10,7 @@ import type { UpdateValidatorRequest, DeleteValidatorRequest, GetContractSchemaRequest, + GetProvidersAndModelsData, } from '@/types'; export interface IJsonRpcService { @@ -26,7 +27,7 @@ export interface IJsonRpcService { request: GetDeployedContractSchemaRequest, ): Promise>; getValidators(): Promise>; - getProvidersAndModels(): Promise>; + getProvidersAndModels(): Promise>; createValidator(request: CreateValidatorRequest): Promise>; updateValidator(request: UpdateValidatorRequest): Promise>; deleteValidator(request: DeleteValidatorRequest): Promise>; diff --git a/frontend/src/services/JsonRpcService.ts b/frontend/src/services/JsonRpcService.ts index 1b3cdac6e..04493c0f4 100644 --- a/frontend/src/services/JsonRpcService.ts +++ b/frontend/src/services/JsonRpcService.ts @@ -12,6 +12,7 @@ import type { CreateValidatorRequest, UpdateValidatorRequest, TransactionItem, + GetProvidersAndModelsData, } from '@/types'; export class JsonRpcService implements IJsonRpcService { @@ -124,10 +125,12 @@ export class JsonRpcService implements IJsonRpcService { /** * Retrieves a list of providers and models from the JSON-RPC server. * - * @return {Promise>} A promise that resolves to the list of providers and models. + * @return {Promise>} A promise that resolves to the list of providers and models. */ - async getProvidersAndModels(): Promise> { - const { result } = await this.rpcClient.call({ + async getProvidersAndModels(): Promise< + JsonRpcResult + > { + const { result } = await this.rpcClient.call({ method: 'get_providers_and_models', params: [], }); diff --git a/frontend/src/stores/node.ts b/frontend/src/stores/node.ts index 3bbf76325..febbfc682 100644 --- a/frontend/src/stores/node.ts +++ b/frontend/src/stores/node.ts @@ -64,7 +64,17 @@ export const useNodeStore = defineStore('nodeStore', () => { } if (modelsResult?.status === 'success') { - nodeProviders.value = modelsResult.data; + nodeProviders.value = modelsResult.data.reduce( + (acc: Record, llmprovider) => { + const provider = llmprovider.provider; + if (!acc[provider]) { + acc[provider] = []; + } + acc[provider].push(llmprovider.model); + return acc; + }, + {}, + ); } else { throw new Error('Error getting Providers and Models data'); } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 8aeeaac0c..94c4c358f 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -12,6 +12,8 @@ export interface ValidatorModel { provider: string; stake: number; updated_at: string; + plugin: string; + plugin_config: Record; } export interface NewValidatorDataModel { diff --git a/frontend/src/types/results.ts b/frontend/src/types/results.ts index 394e25125..c5b4f8ebf 100644 --- a/frontend/src/types/results.ts +++ b/frontend/src/types/results.ts @@ -5,3 +5,13 @@ export interface JsonRpcResult { } export interface GetContractStateResult extends Record {} + +export interface GetProvidersAndModelsData + extends Array<{ + config: Record; + id: number; + model: string; + plugin: string; + plugin_config: Record; + provider: string; + }> {} diff --git a/frontend/test/unit/stores/node.test.ts b/frontend/test/unit/stores/node.test.ts index 359a29426..f78b0a0ba 100644 --- a/frontend/test/unit/stores/node.test.ts +++ b/frontend/test/unit/stores/node.test.ts @@ -14,6 +14,8 @@ const testValidator1: ValidatorModel = { model: 'gpt-4', config: '{}', updated_at: new Date().toISOString(), + plugin: 'openai', + plugin_config: {}, }; const testValidator2: ValidatorModel = { @@ -24,6 +26,8 @@ const testValidator2: ValidatorModel = { model: 'llama3', config: '{}', updated_at: new Date().toISOString(), + plugin: 'ollama', + plugin_config: {}, }; const testLog: NodeLog = { From 25b39c9e590b36ed61a955fd8ef958ca1acdb371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 13 Sep 2024 11:24:06 -0300 Subject: [PATCH 71/75] fix endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/protocol_rpc/endpoints.py | 8 +++++++- tests/plugins/test_llms.py | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/backend/protocol_rpc/endpoints.py b/backend/protocol_rpc/endpoints.py index 51afd95a8..c297a6aa6 100644 --- a/backend/protocol_rpc/endpoints.py +++ b/backend/protocol_rpc/endpoints.py @@ -99,7 +99,13 @@ def get_contract_schema( validator=Validator( address="", stake=0, - llmprovider=None, + llmprovider=LLMProvider( + provider="", + model="", + config={}, + plugin="", + plugin_config={}, + ), ), leader_receipt=None, msg_handler=msg_handler, diff --git a/tests/plugins/test_llms.py b/tests/plugins/test_llms.py index 602797099..dc3ec3881 100644 --- a/tests/plugins/test_llms.py +++ b/tests/plugins/test_llms.py @@ -6,6 +6,8 @@ """ import asyncio + +import pytest from backend.node.genvm.llms import AnthropicPlugin, OllamaPlugin, OpenAIPlugin @@ -32,14 +34,23 @@ def test_openai_plugin(): assert result != None and result != "" and isinstance(result, str) -def test_heuristai_plugin(): +@pytest.mark.parametrize( + "model", + [ + "mistralai/mixtral-8x7b-instruct", + "meta-llama/llama-2-70b-chat", + "openhermes-2-yi-34b-gptq", + "dolphin-2.9-llama3-8b", + ], +) +def test_heuristai_plugin(model): plugin_config = { "api_key_env_var": "HEURISTAIAPIKEY", "api_url": "https://llm-gateway.heurist.xyz", } node_config = { "provider": "heuristai", - "model": "mistralai/mixtral-8x7b-instruct", + "model": model, "config": {"temperature": 0, "max_tokens": 10}, "plugin_config": plugin_config, } From 770adc674b330b748e35fb3d7ec7402c85c143f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 13 Sep 2024 11:28:57 -0300 Subject: [PATCH 72/75] add new openai models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- .../create_nodes/default_providers/openai_o1-mini.json | 10 ++++++++++ .../default_providers/openai_o1-preview.json | 10 ++++++++++ backend/node/create_nodes/providers_schema.json | 2 +- 3 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 backend/node/create_nodes/default_providers/openai_o1-mini.json create mode 100644 backend/node/create_nodes/default_providers/openai_o1-preview.json diff --git a/backend/node/create_nodes/default_providers/openai_o1-mini.json b/backend/node/create_nodes/default_providers/openai_o1-mini.json new file mode 100644 index 000000000..52fea1f25 --- /dev/null +++ b/backend/node/create_nodes/default_providers/openai_o1-mini.json @@ -0,0 +1,10 @@ +{ + "provider": "openai", + "plugin": "openai", + "model": "o1-mini", + "config": {}, + "plugin_config": { + "api_key_env_var": "OPENAIKEY", + "api_url": null + } +} diff --git a/backend/node/create_nodes/default_providers/openai_o1-preview.json b/backend/node/create_nodes/default_providers/openai_o1-preview.json new file mode 100644 index 000000000..0fc373f8b --- /dev/null +++ b/backend/node/create_nodes/default_providers/openai_o1-preview.json @@ -0,0 +1,10 @@ +{ + "provider": "openai", + "plugin": "openai", + "model": "o1-preview", + "config": {}, + "plugin_config": { + "api_key_env_var": "OPENAIKEY", + "api_url": null + } +} diff --git a/backend/node/create_nodes/providers_schema.json b/backend/node/create_nodes/providers_schema.json index 0830b7475..7b38a8b09 100644 --- a/backend/node/create_nodes/providers_schema.json +++ b/backend/node/create_nodes/providers_schema.json @@ -70,7 +70,7 @@ "const": "openai" }, "model": { - "enum": ["gpt-4", "gpt-4o", "gpt-4o-mini"] + "enum": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-preview", "o1-mini"] } } } From 8d2600f20a1efe5ba26d06af3509cc5403301ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Fri, 13 Sep 2024 12:18:52 -0300 Subject: [PATCH 73/75] fix anthropic plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/genvm/llms.py | 34 ++++++++++++++++++++-------------- tests/plugins/test_llms.py | 2 +- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index be0d82da2..a43dcbcad 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -238,13 +238,19 @@ async def call( regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], ) -> str: + client: AsyncAnthropic = None if self.url: client = AsyncAnthropic(api_key=self.get_api_key(), base_url=self.url) else: client = AsyncAnthropic(api_key=self.get_api_key()) + if "max_tokens" not in node_config["config"]: + raise ValueError("`max_tokens` is required for Anthropic") + buffer = "" - async with client.messages.create( + + # Not using `async with` (https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#streaming-helpers) since I get a `'coroutine' object does not support the asynchronous context manager protocol`. Probably related to how the `EquivalencePrinciple` class implements + stream = await client.messages.create( model=node_config["model"], messages=[ { @@ -256,19 +262,19 @@ async def call( **node_config[ "config" ], # max_tokens, temperature, top_k, top_p, timeout, stop_sequences - ) as stream: - async for event in stream: - if event.type == "text": - buffer += event.text - if return_streaming_channel is not None: - await return_streaming_channel.put(event.text) - match = re.search(regex, buffer) - if match: - return match.group(0) - elif event.type == "content_block_stop": - break - - return await stream.get_final_message() + ) + async for event in stream: + if event.type == "text": + buffer += event.text + if return_streaming_channel is not None: + await return_streaming_channel.put(event.text) + match = re.search(regex, buffer) + if match: + return match.group(0) + elif event.type == "content_block_stop": + break + + return buffer def is_available(self) -> bool: env_var = self.get_api_key() diff --git a/tests/plugins/test_llms.py b/tests/plugins/test_llms.py index dc3ec3881..a5450389a 100644 --- a/tests/plugins/test_llms.py +++ b/tests/plugins/test_llms.py @@ -112,7 +112,7 @@ def test_ollama_plugin(): def test_anthropic_plugin(): - plugin_config = {"api_key_env_var": "ANTROPIC_API_KEY"} + plugin_config = {"api_key_env_var": "ANTROPIC_API_KEY", "api_url": None} node_config = { "provider": "anthropic", "model": "claude-3-5-sonnet-20240620", From 6f85c65c7f3e645e191d9faa8a7de6fe169bc79d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Sun, 15 Sep 2024 14:34:45 -0300 Subject: [PATCH 74/75] use protocolt instead of ABC for Plugin interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/genvm/llms.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/backend/node/genvm/llms.py b/backend/node/genvm/llms.py index a43dcbcad..07de32db3 100644 --- a/backend/node/genvm/llms.py +++ b/backend/node/genvm/llms.py @@ -7,7 +7,7 @@ - `return_streaming_channel`: An optional asyncio.Queue to stream the response. """ -from abc import ABC, abstractmethod +from typing import Protocol import os import re import json @@ -143,31 +143,23 @@ async def get_openai_output( return buffer -class Plugin(ABC): - @abstractmethod - def __init__(self, plugin_config: dict): - pass +class Plugin(Protocol): + def __init__(self, plugin_config: dict): ... - @abstractmethod async def call( self, node_config: dict, prompt: str, regex: Optional[str], return_streaming_channel: Optional[asyncio.Queue], - ) -> str: - pass + ) -> str: ... - @abstractmethod - def is_available(self) -> bool: - pass + def is_available(self) -> bool: ... - @abstractmethod - def is_model_available(self, model: str) -> bool: - pass + def is_model_available(self, model: str) -> bool: ... -class OllamaPlugin(Plugin): +class OllamaPlugin: def __init__(self, plugin_config: dict): self.url = plugin_config["api_url"] @@ -197,7 +189,7 @@ def is_model_available(self, model: str) -> bool: return model in installed_ollama_models -class OpenAIPlugin(Plugin): +class OpenAIPlugin: def __init__(self, plugin_config: dict): self.api_key_env_var = plugin_config["api_key_env_var"] @@ -226,7 +218,7 @@ def is_model_available(self, model: str) -> bool: return True -class AnthropicPlugin(Plugin): +class AnthropicPlugin: def __init__(self, plugin_config: dict): self.api_key_env_var = plugin_config["api_key_env_var"] self.url = plugin_config["api_url"] From a16920c4a0a8841312df688cb113df54b0638598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Ramiro=20D=C3=ADaz?= Date: Mon, 16 Sep 2024 15:30:25 -0300 Subject: [PATCH 75/75] feat: load default providers in a thread and cache MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Ramiro Díaz --- backend/node/create_nodes/providers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/backend/node/create_nodes/providers.py b/backend/node/create_nodes/providers.py index 28018116b..58b53f696 100644 --- a/backend/node/create_nodes/providers.py +++ b/backend/node/create_nodes/providers.py @@ -1,5 +1,6 @@ import json import os +from threading import Thread from typing import List from jsonschema import Draft202012Validator, validate @@ -10,6 +11,8 @@ schema_file = os.path.join(current_directory, "providers_schema.json") default_providers_folder = os.path.join(current_directory, "default_providers") +default_providers_cache: List[LLMProvider] = [] + def get_schema() -> dict: with open(schema_file, "r") as f: @@ -33,6 +36,10 @@ def validate_provider(provider: LLMProvider): def get_default_providers() -> List[LLMProvider]: + global default_providers_cache + if default_providers_cache: + return default_providers_cache + schema = get_schema() files = [ @@ -52,9 +59,15 @@ def get_default_providers() -> List[LLMProvider]: providers.append(_to_domain(provider)) + default_providers_cache = providers return providers +# Start in another thread to avoid blocking the main thread +thread = Thread(target=get_default_providers, args=()) +thread.start() + + def get_default_provider_for(provider: str, model: str) -> LLMProvider: llm_providers = get_default_providers() matches = [