diff --git a/templates/components/azure-openai.jsonnet b/templates/components/azure-openai.jsonnet
index 31e19720..aa759b51 100644
--- a/templates/components/azure-openai.jsonnet
+++ b/templates/components/azure-openai.jsonnet
@@ -5,7 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet";
 
 {
 
-    "azure-openai-token":: "${AZURE_OPENAI_TOKEN}",
     "azure-openai-model":: "GPT-3.5-Turbo",
     "azure-openai-max-output-tokens":: 4192,
     "azure-openai-temperature":: 0.0,
@@ -14,6 +13,9 @@ local prompts = import "prompts/mixtral.jsonnet";
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("azure-openai-credentials")
+                .with_env_var("AZURE_TOKEN", "azure-token");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -21,8 +23,6 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "text-completion-azure-openai",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["azure-openai-token"],
                         "-m",
                         $["azure-openai-model"],
                         "-x",
@@ -30,39 +30,17 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "-t",
                         std.toString($["azure-openai-temperature"]),
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8000, 8000, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
                         "text-completion-azure",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["azure-openai-token"],
-                        "-e",
-                        $["azure-openai-model"],
                         "-x",
                         std.toString($["azure-openai-max-output-tokens"]),
                         "-t",
@@ -72,24 +50,35 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8000, 8000, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8000, 8000, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts
diff --git a/templates/components/azure.jsonnet b/templates/components/azure.jsonnet
index 3ee819ee..aacbeac4 100644
--- a/templates/components/azure.jsonnet
+++ b/templates/components/azure.jsonnet
@@ -5,8 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet";
 
 {
 
-    "azure-token":: "${AZURE_TOKEN}",
-    "azure-endpoint":: "${AZURE_ENDPOINT}",
     "azure-max-output-tokens":: 4096,
     "azure-temperature":: 0.0,
@@ -14,6 +12,10 @@ local prompts = import "prompts/mixtral.jsonnet";
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("azure-credentials")
+                .with_env_var("AZURE_TOKEN", "azure-token")
+                .with_env_var("AZURE_ENDPOINT", "azure-endpoint");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -21,48 +23,22 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "text-completion-azure",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["azure-token"],
-                        "-e",
-                        $["azure-endpoint"],
                         "-x",
                         std.toString($["azure-max-output-tokens"]),
                         "-t",
                         std.toString($["azure-temperature"]),
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8000, 8000, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
                         "text-completion-azure",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["azure-token"],
-                        "-e",
-                        $["azure-endpoint"],
                         "-x",
                         std.toString($["azure-max-output-tokens"]),
                         "-t",
@@ -72,23 +48,34 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8000, 8000, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8000, 8000, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
     }
 
 } + prompts
diff --git a/templates/components/bedrock.jsonnet b/templates/components/bedrock.jsonnet
index 1c375621..d519363e 100644
--- a/templates/components/bedrock.jsonnet
+++ b/templates/components/bedrock.jsonnet
@@ -6,9 +6,6 @@ local chunker = import "chunker-recursive.jsonnet";
 
 {
 
-    "aws-id-key":: "${AWS_ID_KEY}",
-    "aws-secret-key":: "${AWS_SECRET_KEY}",
-    "aws-region":: "us-west-2",
     "bedrock-max-output-tokens":: 4096,
     "bedrock-temperature":: 0.0,
     "bedrock-model":: "mistral.mixtral-8x7b-instruct-v0:1",
@@ -17,6 +14,11 @@ local chunker = import "chunker-recursive.jsonnet";
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("bedrock-credentials")
+                .with_env_var("AWS_ID_KEY", "aws-id-key")
+                .with_env_var("AWS_SECRET_KEY", "aws-secret-key")
+                .with_env_var("AWS_REGION", "aws-region");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -24,12 +26,6 @@ local chunker = import "chunker-recursive.jsonnet";
                         "text-completion-bedrock",
                         "-p",
                         url.pulsar,
-                        "-z",
-                        $["aws-id-key"],
-                        "-k",
-                        $["aws-secret-key"],
-                        "-r",
-                        $["aws-region"],
                         "-x",
                         std.toString($["bedrock-max-output-tokens"]),
                         "-t",
@@ -37,41 +33,17 @@ local chunker = import "chunker-recursive.jsonnet";
                         "-m",
                         $["bedrock-model"],
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8000, 8000, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
                         "text-completion-bedrock",
                         "-p",
                         url.pulsar,
-                        "-z",
-                        $["aws-id-key"],
-                        "-k",
-                        $["aws-secret-key"],
-                        "-r",
-                        $["aws-region"],
                         "-x",
                         std.toString($["bedrock-max-output-tokens"]),
                         "-t",
@@ -83,24 +55,35 @@ local chunker = import "chunker-recursive.jsonnet";
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8000, 8000, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8000, 8000, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts + chunker
diff --git a/templates/components/claude.jsonnet b/templates/components/claude.jsonnet
index 0cd190d4..b723a16f 100644
--- a/templates/components/claude.jsonnet
+++ b/templates/components/claude.jsonnet
@@ -5,7 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet";
 
 {
 
-    "claude-key":: "${CLAUDE_KEY}",
     "claude-max-output-tokens":: 4096,
     "claude-temperature":: 0.0,
@@ -13,6 +12,9 @@ local prompts = import "prompts/mixtral.jsonnet";
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("claude-credentials")
+                .with_env_var("CLAUDE_KEY_TOKEN", "claude-key");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -20,44 +22,22 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "text-completion-claude",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["claude-key"],
                         "-x",
                         std.toString($["claude-max-output-tokens"]),
                         "-t",
                         std.toString($["claude-temperature"]),
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8000, 8000, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
                         "text-completion-claude",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["claude-key"],
                         "-x",
                         std.toString($["claude-max-output-tokens"]),
                         "-t",
@@ -67,24 +47,35 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8000, 8000, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8000, 8000, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts
diff --git a/templates/components/cohere.jsonnet b/templates/components/cohere.jsonnet
index f05cb635..c2027f3c 100644
--- a/templates/components/cohere.jsonnet
+++ b/templates/components/cohere.jsonnet
@@ -9,13 +9,15 @@ local prompts = import "prompts/mixtral.jsonnet";
 
     "chunk-size":: 150,
     "chunk-overlap":: 10,
-    "cohere-key":: "${COHERE_KEY}",
     "cohere-temperature":: 0.0,
 
     "text-completion" +: {
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("cohere-credentials")
+                .with_env_var("COHERE_KEY", "cohere-key");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -23,42 +25,19 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "text-completion-cohere",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["cohere-key"],
                         "-t",
                         std.toString($["cohere-temperature"]),
                     ])
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8000, 8000, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
                         "text-completion-cohere",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["cohere-key"],
                         "-t",
                         std.toString($["cohere-temperature"]),
                         "-i",
@@ -70,20 +49,30 @@ local prompts = import "prompts/mixtral.jsonnet";
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8000, 8000, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8000, 8000, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts
diff --git a/templates/components/llamafile.jsonnet b/templates/components/llamafile.jsonnet
index 93163a14..d51cda61 100644
--- a/templates/components/llamafile.jsonnet
+++ b/templates/components/llamafile.jsonnet
@@ -6,12 +6,14 @@ local prompts = import "prompts/slm.jsonnet";
 {
 
     "llamafile-model":: "LLaMA_CPP",
-    "llamafile-url":: "${LLAMAFILE_URL}",
 
     "text-completion" +: {
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("llamafile-credentials")
+                .with_env_var("LLAMAFILE_URL", "llamafile-url");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -21,27 +23,12 @@ local prompts = import "prompts/slm.jsonnet";
                         url.pulsar,
                         "-m",
                         $["llamafile-model"],
-                        "-r",
-                        $["llamafile-url"],
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            engine.resources([
-                containerSet,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
@@ -50,26 +37,40 @@ local prompts = import "prompts/slm.jsonnet";
                         url.pulsar,
                         "-m",
                         $["llamafile-model"],
-                        "-r",
-                        $["llamafile-url"],
                         "-i",
                         "non-persistent://tg/request/text-completion-rag",
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
             );
 
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
+            );
+
+            local service =
+                engine.internalService(containerSet)
+                .with_port(8080, 8080, "metrics");
+
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8080, 8080, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
+                service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts
diff --git a/templates/components/ollama.jsonnet b/templates/components/ollama.jsonnet
index b0507cef..3e6cc91e 100644
--- a/templates/components/ollama.jsonnet
+++ b/templates/components/ollama.jsonnet
@@ -6,12 +6,14 @@ local prompts = import "prompts/slm.jsonnet";
 {
 
     "ollama-model":: "gemma2:9b",
-    "ollama-url":: "${OLLAMA_HOST}",
 
     "text-completion" +: {
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("ollama-credentials")
+                .with_env_var("OLLAMA_HOST", "ollama-host");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -21,32 +23,12 @@ local prompts = import "prompts/slm.jsonnet";
                         url.pulsar,
                         "-m",
                         $["ollama-model"],
-                        "-r",
-                        $["ollama-url"],
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8080, 8080, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
@@ -55,31 +37,40 @@ local prompts = import "prompts/slm.jsonnet";
                         url.pulsar,
                         "-m",
                         $["ollama-model"],
-                        "-r",
-                        $["ollama-url"],
                         "-i",
                         "non-persistent://tg/request/text-completion-rag",
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8080, 8080, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8080, 8080, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts
diff --git a/templates/components/openai.jsonnet b/templates/components/openai.jsonnet
index 3d1a2b73..74290420 100644
--- a/templates/components/openai.jsonnet
+++ b/templates/components/openai.jsonnet
@@ -5,7 +5,6 @@ local prompts = import "prompts/mixtral.jsonnet";
 
 {
 
-    "openai-key":: "${OPENAI_KEY}",
     "openai-max-output-tokens":: 4096,
     "openai-temperature":: 0.0,
     "openai-model":: "GPT-3.5-Turbo",
@@ -14,6 +13,9 @@ local prompts = import "prompts/mixtral.jsonnet";
 
         create:: function(engine)
 
+            local envSecrets = engine.envSecrets("openai-credentials")
+                .with_env_var("OPENAI_TOKEN", "openai-token");
+
             local container =
                 engine.container("text-completion")
                     .with_image(images.trustgraph)
@@ -21,8 +23,6 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "text-completion-openai",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["openai-key"],
                         "-x",
                         std.toString($["openai-max-output-tokens"]),
                         "-t",
@@ -30,37 +30,17 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "-m",
                         $["openai-model"],
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
-            local containerSet = engine.containers(
-                "text-completion", [ container ]
-            );
-
-            local service =
-                engine.internalService(containerSet)
-                .with_port(8080, 8080, "metrics");
-
-            engine.resources([
-                containerSet,
-                service,
-            ])
-
-    },
-
-    "text-completion-rag" +: {
-
-        create:: function(engine)
-
-            local container =
+            local containerRag =
                 engine.container("text-completion-rag")
                     .with_image(images.trustgraph)
                     .with_command([
                         "text-completion-openai",
                         "-p",
                         url.pulsar,
-                        "-k",
-                        $["openai-key"],
                         "-x",
                         std.toString($["openai-max-output-tokens"]),
                         "-t",
@@ -72,24 +52,35 @@ local prompts = import "prompts/mixtral.jsonnet";
                         "-o",
                         "non-persistent://tg/response/text-completion-rag-response",
                     ])
+                    .with_env_var_secrets(envSecrets)
                     .with_limits("0.5", "128M")
                     .with_reservations("0.1", "128M");
 
             local containerSet = engine.containers(
-                "text-completion-rag", [ container ]
+                "text-completion", [ container ]
+            );
+
+            local containerSetRag = engine.containers(
+                "text-completion-rag", [ containerRag ]
             );
 
             local service =
                 engine.internalService(containerSet)
                 .with_port(8080, 8080, "metrics");
 
+            local serviceRag =
+                engine.internalService(containerSetRag)
+                .with_port(8080, 8080, "metrics");
+
             engine.resources([
+                envSecrets,
                 containerSet,
+                containerSetRag,
                 service,
+                serviceRag,
             ])
-
-    }
+    },
 
 } + prompts
diff --git a/templates/engine/docker-compose.jsonnet b/templates/engine/docker-compose.jsonnet
index 4f837ff2..c37f1df0 100644
--- a/templates/engine/docker-compose.jsonnet
+++ b/templates/engine/docker-compose.jsonnet
@@ -18,12 +18,15 @@
             reservations: {},
             ports: [],
             volumes: [],
+            environment: {},
 
             with_image:: function(x) self + { image: x },
 
             with_command:: function(x) self + { command: x },
 
-            with_environment:: function(x) self + { environment: x },
+            with_environment:: function(x) self + {
+                environment: super.environment + x,
+            },
 
             with_limits::
                 function(c, m) self + { limits: { cpus: c, memory: m } },
@@ -45,6 +48,16 @@
                 ]
             },
 
+            with_env_var_secrets::
+                function(vars)
+                    std.foldl(
+                        function(obj, x) obj.with_environment(
+                            { [x]: "${" + x + "}" }
+                        ),
+                        vars.variables,
+                        self
+                    ),
+
             add:: function() {
                 services +: {
                     [container.name]: {
@@ -62,7 +75,7 @@
                         { command: container.command }
                     else
                         {}) +
-                    (if std.objectHas(container, "environment") then
+                    (if ! std.isEmpty(container.environment) then
                         { environment: container.environment }
                     else
                         {}) +
@@ -170,6 +183,27 @@
 
         },
 
+        envSecrets:: function(name)
+        {
+
+            local volume = self,
+
+            name: name,
+
+            volid:: name,
+
+            variables:: [],
+
+            with_env_var::
+                function(name, key) self + {
+                    variables: super.variables + [name],
+                },
+
+            add:: function() {
+            }
+
+        },
+
         containers:: function(name, containers) {
diff --git a/templates/engine/k8s.jsonnet b/templates/engine/k8s.jsonnet
index 69aabfd7..2fec0d1f 100644
--- a/templates/engine/k8s.jsonnet
+++ b/templates/engine/k8s.jsonnet
@@ -10,12 +10,20 @@
             reservations: {},
             ports: [],
             volumes: [],
+            environment: [],
 
             with_image:: function(x) self + { image: x },
 
             with_command:: function(x) self + { command: x },
 
-            with_environment:: function(x) self + { environment: x },
+            with_environment:: function(x) self + {
+                environment: super.environment + [
+                    {
+                        name: v.key, value: v.value
+                    }
+                    for v in std.objectKeysValues(x)
+                ],
+            },
 
             with_limits::
                 function(c, m) self + { limits: { cpu: c, memory: m } },
@@ -37,6 +45,24 @@
                 ]
             },
 
+            with_env_var_secrets::
+                function(vars)
+                    std.foldl(
+                        function(obj, x) obj + {
+                            environment: super.environment + [{
+                                name: x,
+                                valueFrom: {
+                                    secretKeyRef: {
+                                        name: vars.name,
+                                        key: vars.keyMap[x],
+                                    }
+                                }
+                            }]
+                        },
+                        vars.variables,
+                        self
+                    ),
+
             add:: function() [
                 {
@@ -97,16 +123,11 @@
                     (if std.objectHas(container, "command") then
                         { command: container.command }
                     else {}) +
-                    (if std.objectHas(container, "environment") then
-                        { env: [ {
-                            name: e.key, value: e.value
-                        }
-                        for e in
-                            std.objectKeysValues(
-                                container.environment
-                            )
-                        ]
-                        }
+
+                    (if ! std.isEmpty(container.environment) then
+                        {
+                            env: container.environment,
+                        }
                     else {}) +
 
                     (if std.length(container.volumes) > 0 then
@@ -283,6 +304,34 @@
 
         },
 
+        envSecrets:: function(name)
+        {
+
+            local volume = self,
+
+            name: name,
+
+            variables: [],
+            keyMap: {},
+
+            with_size:: function(size) self + { size: size },
+
+            add:: function() [
+            ],
+
+            volRef:: function() {
+                name: volume.name,
+                secret: { secretName: volume.name },
+            },
+
+            with_env_var::
+                function(name, key) self + {
+                    variables: super.variables + [name],
+                    keyMap: super.keyMap + { [name]: key },
+                },
+
+        },
+
         containers:: function(name, containers) {
diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
index 0d050261..f06d143e 100755
--- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
+++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py
@@ -7,6 +7,7 @@
 import boto3
 import json
 from prometheus_client import Histogram
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -21,10 +22,11 @@
 default_output_queue = text_completion_response_queue
 default_subscriber = module
 default_model = 'mistral.mistral-large-2407-v1:0'
-default_region = 'us-west-2'
+default_region = os.getenv("AWS_REGION", 'us-west-2')
 default_temperature = 0.0
 default_max_output = 2048
-
+default_aws_id = os.getenv("AWS_ID_KEY")
+default_aws_secret = os.getenv("AWS_SECRET_KEY")
 
 class Processor(ConsumerProducer):
 
@@ -34,12 +36,18 @@ def __init__(self, **params):
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
         model = params.get("model", default_model)
-        aws_id = params.get("aws_id_key")
-        aws_secret = params.get("aws_secret")
+        aws_id = params.get("aws_id_key", default_aws_id)
+        aws_secret = params.get("aws_secret", default_aws_secret)
         aws_region = params.get("aws_region", default_region)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
 
+        if aws_id is None:
+            raise RuntimeError("AWS ID not specified")
+
+        if aws_secret is None:
+            raise RuntimeError("AWS secret not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
@@ -299,7 +307,7 @@ def add_args(parser):
 
     parser.add_argument(
         '-r', '--aws-region',
-        help=f'AWS Region (default: us-west-2)'
+        help=f'AWS Region'
     )
 
     parser.add_argument(
@@ -320,4 +328,3 @@ def run():
 
     Processor.start(module, __doc__)
 
-
diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
index ff97f644..949eeae5 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
@@ -7,6 +7,7 @@
 import requests
 import json
 from prometheus_client import Histogram
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -23,6 +24,8 @@
 default_temperature = 0.0
 default_max_output = 4192
 default_model = "AzureAI"
+default_endpoint = os.getenv("AZURE_ENDPOINT")
+default_token = os.getenv("AZURE_TOKEN")
 
 class Processor(ConsumerProducer):
 
@@ -31,12 +34,18 @@ def __init__(self, **params):
         input_queue = params.get("input_queue", default_input_queue)
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
-        endpoint = params.get("endpoint")
-        token = params.get("token")
+        endpoint = params.get("endpoint", default_endpoint)
+        token = params.get("token", default_token)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
         model = default_model
 
+        if endpoint is None:
+            raise RuntimeError("Azure endpoint not specified")
+
+        if token is None:
+            raise RuntimeError("Azure token not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
diff --git a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
index d0939d1e..1dbfba27 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/azure_openai/llm.py
@@ -8,6 +8,7 @@
 import json
 from prometheus_client import Histogram
 from openai import AzureOpenAI
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -24,6 +25,8 @@
 default_temperature = 0.0
 default_max_output = 4192
 default_api = "2024-02-15-preview"
+default_endpoint = os.getenv("AZURE_ENDPOINT")
+default_token = os.getenv("AZURE_TOKEN")
 
 class Processor(ConsumerProducer):
 
@@ -32,13 +35,19 @@ def __init__(self, **params):
         input_queue = params.get("input_queue", default_input_queue)
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
-        endpoint = params.get("endpoint")
-        token = params.get("token")
+        endpoint = params.get("endpoint", default_endpoint)
+        token = params.get("token", default_token)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
         model = params.get("model")
         api = params.get("api_version", default_api)
 
+        if endpoint is None:
+            raise RuntimeError("Azure endpoint not specified")
+
+        if token is None:
+            raise RuntimeError("Azure token not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
diff --git a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
index ad949b02..08e2828d 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
@@ -6,6 +6,7 @@
 
 import anthropic
 from prometheus_client import Histogram
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -22,6 +23,7 @@
 default_model = 'claude-3-5-sonnet-20240620'
 default_temperature = 0.0
 default_max_output = 8192
+default_api_key = os.getenv("CLAUDE_KEY")
 
 class Processor(ConsumerProducer):
 
@@ -31,10 +33,13 @@ def __init__(self, **params):
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
         model = params.get("model", default_model)
-        api_key = params.get("api_key")
+        api_key = params.get("api_key", default_api_key)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
 
+        if api_key is None:
+            raise RuntimeError("Claude API key not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
index 4c64e8b6..8e1f4f7c 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
@@ -6,6 +6,7 @@
 
 import cohere
 from prometheus_client import Histogram
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -21,6 +22,7 @@
 default_subscriber = module
 default_model = 'c4ai-aya-23-8b'
 default_temperature = 0.0
+default_api_key = os.getenv("COHERE_KEY")
 
 class Processor(ConsumerProducer):
 
@@ -30,9 +32,12 @@ def __init__(self, **params):
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
         model = params.get("model", default_model)
-        api_key = params.get("api_key")
+        api_key = params.get("api_key", default_api_key)
         temperature = params.get("temperature", default_temperature)
 
+        if api_key is None:
+            raise RuntimeError("Cohere API key not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
diff --git a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
index baa00b04..47a75927 100644
--- a/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/googleaistudio/llm.py
@@ -7,6 +7,7 @@
 import google.generativeai as genai
 from google.generativeai.types import HarmCategory, HarmBlockThreshold
 from prometheus_client import Histogram
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -23,6 +24,7 @@
 default_model = 'gemini-1.5-flash-002'
 default_temperature = 0.0
 default_max_output = 8192
+default_api_key = os.getenv("GOOGLE_AI_STUDIO_KEY")
 
 class Processor(ConsumerProducer):
 
@@ -32,10 +34,13 @@ def __init__(self, **params):
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
         model = params.get("model", default_model)
-        api_key = params.get("api_key")
+        api_key = params.get("api_key", default_api_key)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
 
+        if api_key is None:
+            raise RuntimeError("Google AI Studio API key not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
@@ -211,4 +216,4 @@ def run():
 
     Processor.start(module, __doc__)
 
-
\ No newline at end of file
+
diff --git a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
index 86427167..a25ad8ec 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
@@ -20,7 +20,7 @@
 default_output_queue = text_completion_response_queue
 default_subscriber = module
 default_model = 'LLaMA_CPP'
-default_llamafile = 'http://localhost:8080/v1'
+default_llamafile = os.getenv("LLAMAFILE_URL", "http://localhost:8080/v1")
 default_temperature = 0.0
 default_max_output = 4096
 
diff --git a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
index b506b3cd..17151a00 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
@@ -6,6 +6,7 @@
 
 from ollama import Client
 from prometheus_client import Histogram, Info
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -19,8 +20,8 @@
 default_input_queue = text_completion_request_queue
 default_output_queue = text_completion_response_queue
 default_subscriber = module
-default_model = 'gemma2'
-default_ollama = 'http://localhost:11434'
+default_model = 'gemma2:9b'
+default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')
 
 class Processor(ConsumerProducer):
 
@@ -152,7 +153,7 @@ def add_args(parser):
     parser.add_argument(
         '-m', '--model',
         default="gemma2",
-        help=f'LLM model (default: gemma2)'
+        help=f'LLM model (default: {default_model})'
     )
 
     parser.add_argument(
diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
index 5d259e7e..b63ad43b 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
@@ -6,6 +6,7 @@
 
 from openai import OpenAI
 from prometheus_client import Histogram
+import os
 
 from .... schema import TextCompletionRequest, TextCompletionResponse, Error
 from .... schema import text_completion_request_queue
@@ -22,6 +23,7 @@
 default_model = 'gpt-3.5-turbo'
 default_temperature = 0.0
 default_max_output = 4096
+default_api_key = os.getenv("OPENAI_KEY")
 
 class Processor(ConsumerProducer):
 
@@ -31,10 +33,13 @@ def __init__(self, **params):
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
         model = params.get("model", default_model)
-        api_key = params.get("api_key")
+        api_key = params.get("api_key", default_api_key)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
 
+        if api_key is None:
+            raise RuntimeError("OpenAI API key not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
diff --git a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
index c57b9fb0..76d517a1 100755
--- a/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
+++ b/trustgraph-vertexai/trustgraph/model/text_completion/vertexai/llm.py
@@ -7,6 +7,7 @@
 import vertexai
 import time
 from prometheus_client import Histogram
+import os
 
 from google.oauth2 import service_account
 import google
@@ -38,6 +39,7 @@
 default_region = 'us-central1'
 default_temperature = 0.0
 default_max_output = 8192
+default_private_key = os.getenv("VERTEXAI_KEY")
 
 class Processor(ConsumerProducer):
 
@@ -48,10 +50,13 @@ def __init__(self, **params):
         subscriber = params.get("subscriber", default_subscriber)
         region = params.get("region", default_region)
         model = params.get("model", default_model)
-        private_key = params.get("private_key")
+        private_key = params.get("private_key", default_private_key)
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)
 
+        if private_key is None:
+            raise RuntimeError("Private key file not specified")
+
         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
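
A note on supplying the secret values (not part of the patch): the new
envSecrets construct only records which environment variables a container
needs and which secret key each one maps to. Under docker-compose,
with_env_var_secrets renders an entry such as AZURE_TOKEN: "${AZURE_TOKEN}",
so the value is interpolated from the deploying shell's environment or an
.env file. Under k8s, it renders a secretKeyRef against a Secret such as
azure-credentials, and since envSecrets' add emits no resources, that Secret
has to be provisioned out of band. Below is a minimal sketch of a matching
Secret, written in Jsonnet for consistency with the templates; the values
shown are placeholders, not real credentials.

    // Hypothetical Secret consumed by envSecrets("azure-credentials")
    //     .with_env_var("AZURE_TOKEN", "azure-token")
    //     .with_env_var("AZURE_ENDPOINT", "azure-endpoint")
    {
        apiVersion: "v1",
        kind: "Secret",
        metadata: { name: "azure-credentials" },
        // stringData accepts plain-text values; the API server
        // base64-encodes them into data on admission.
        stringData: {
            "azure-token": "<token>",
            "azure-endpoint": "<endpoint>",
        },
    }

The key names under stringData must match the second argument of each
with_env_var call, since those are the keys the generated secretKeyRef
entries look up.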