diff --git a/.gitignore b/.gitignore index 0142a0e..1f05ad1 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,5 @@ cython_debug/ poetry.lock pyproject.toml griptape_config.json +pyproject_old.toml +pyproject_pub.toml diff --git a/README.md b/README.md index 7db57a6..2db8422 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,11 @@ You can previous and download more examples [here](examples/README.md). ## Recent Changelog +### Aug 30, 2024 +* Added `max_tokens` to most configuration and prompt_driver nodes. This gives you the ability to control how many tokens come back from the LLM. _Note: It's a known issue that AmazonBedrock doesn't work with max_tokens at the moment._ +* Added `Griptape Tool: Extraction` node that lets you extract either json or csv text with either a json schema or column header definitions. This works well with TaskMemory. +* Added `Griptape Tool: Prompt Summary` node that will summarize text. This works well with TaskMemory. + ### Aug 29, 2024 * Updated griptape version to 0.30.2 - This is a major change to how Griptape handles configurations, but I tried to ensure all nodes and workflows still work. Please let us know if there are any issues. * Added `Griptape Tool: Query` node to allow Task Memory to go "Off Prompt" diff --git a/__init__.py b/__init__.py index ee6f0cb..d8f3fc9 100644 --- a/__init__.py +++ b/__init__.py @@ -219,8 +219,10 @@ from .nodes.tools.gtUICalculator import gtUICalculator from .nodes.tools.gtUIConvertAgentToTool import gtUIConvertAgentToTool from .nodes.tools.gtUIDateTime import gtUIDateTime +from .nodes.tools.gtUIExtractionTool import gtUIExtractionTool from .nodes.tools.gtUIFileManager import gtUIFileManager from .nodes.tools.gtUIKnowledgeBaseTool import gtUIKnowledgeBaseTool +from .nodes.tools.gtUIPromptSummaryTool import gtUIPromptSummaryTool from .nodes.tools.gtUIQueryTool import gtUIQueryTool from .nodes.tools.gtUITextToSpeechClient import gtUITextToSpeechClient from .nodes.tools.gtUIVectorStoreClient import gtUIVectorStoreClient @@ -346,10 +348,12 @@ "Griptape Tool: FileManager": gtUIFileManager, "Griptape Tool: Griptape Cloud KnowledgeBase": gtUIKnowledgeBaseTool, "Griptape Tool: Text to Speech": gtUITextToSpeechClient, - "Griptape Tool: Query": gtUIQueryTool, "Griptape Tool: VectorStore": gtUIVectorStoreClient, "Griptape Tool: WebScraper": gtUIWebScraper, "Griptape Tool: WebSearch": gtUIWebSearch, + "Griptape Tool: Extraction": gtUIExtractionTool, + "Griptape Tool: Prompt Summary": gtUIPromptSummaryTool, + "Griptape Tool: Query": gtUIQueryTool, # DISPLAY "Griptape Display: Image": gtUIOutputImageNode, "Griptape Display: Text": gtUIOutputStringNode, diff --git a/js/ExtractionNodes.js b/js/ExtractionNodes.js new file mode 100644 index 0000000..94d7221 --- /dev/null +++ b/js/ExtractionNodes.js @@ -0,0 +1,46 @@ +import { ComfyWidgets } from "../../../scripts/widgets.js"; +import { fitHeight } from "./utils.js"; +import { formatAndDisplayJSON } from "./gtUIUtils.js"; +import { hideWidget, showWidget } from "./utils.js"; +import { app } from "../../../scripts/app.js"; +export function setupExtractionNodes(nodeType, nodeData, app) { + if (nodeData.name === "Griptape Tool: Extraction") { + setupExtractionTypeAttrr(nodeType, nodeData, app); + } + } + + function setupExtractionTypeAttrr(nodeType, nodeData, app) { + const onNodeCreated = nodeType.prototype.onNodeCreated + nodeType.prototype.onNodeCreated = function() { + const me = onNodeCreated?.apply(this); + const widget_extraction_type = this.widgets.find(w => w.name === 'extraction_type'); + const widget_column_names = this.widgets.find(w=> w.name === 'column_names'); + const widget_template_schema = this.widgets.find(w=> w.name === 'template_schema'); + + // Hide both widgets + widget_extraction_type.callback = async() => { + hideWidget(this, widget_column_names); + hideWidget(this, widget_template_schema); + + switch (widget_extraction_type.value) { + case "csv": + showWidget(widget_column_names); + // fitHeight(this, true); + break; + case "json": + showWidget(widget_template_schema); + // fitHeight(this, true); + break; + default: + // fitHeight(this, true); + break; + } + } + + setTimeout(() => { widget_extraction_type.callback() }, 5); + return me; + // setupMessageStyle(this.message); + }; + + } + diff --git a/js/gtUIMenuSeparator.js b/js/gtUIMenuSeparator.js index 87bc79f..4550b44 100644 --- a/js/gtUIMenuSeparator.js +++ b/js/gtUIMenuSeparator.js @@ -12,6 +12,7 @@ const sep_above_items = [ "Griptape Agent Config: Amazon Bedrock", // Sub Menu Item - Agent Tools "Griptape Tool: Audio Transcription", + "Griptape Tool: Extraction", // Sub Menu Item - Audio "Griptape Load: Audio", // Sub Menu Items - Image diff --git a/js/gtUINodes.js b/js/gtUINodes.js index 0faf540..ad3b9c5 100644 --- a/js/gtUINodes.js +++ b/js/gtUINodes.js @@ -5,6 +5,7 @@ import { setupConfigurationNodes } from "./ConfigurationNodes.js"; import { setupNodeColors } from "./ColorNodes.js"; import { setupDisplayNodes } from "./DisplayNodes.js"; import { setupCombineNodes } from "./CombineNodes.js"; +import { setupExtractionNodes } from "./ExtractionNodes.js"; import { gtUIAddUploadWidget } from "./gtUIUtils.js"; import { setupMenuSeparator } from "./gtUIMenuSeparator.js"; // app.extensionManager.registerSidebarTab({ @@ -57,6 +58,7 @@ app.registerExtension({ setupConfigurationNodes(nodeType, nodeData, app); setupDisplayNodes(nodeType, nodeData, app); setupCombineNodes(nodeType, nodeData, app); + setupExtractionNodes(nodeType, nodeData, app); // Create Audio Node diff --git a/js/utils.js b/js/utils.js index 3954db3..ffb7696 100644 --- a/js/utils.js +++ b/js/utils.js @@ -1,7 +1,43 @@ +// export function fitHeight(node) { +// node.onResize?.(node.size); +// node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) +// node?.graph?.setDirtyCanvas(true, true); +// } export function fitHeight(node) { - node.onResize?.(node.size); - node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]]) - node?.graph?.setDirtyCanvas(true, true); + if (!node) return null; + + try { + node.onResize?.(node.size); + + // Get the base height from computeSize + let computedHeight = node.computeSize([node.size[0], node.size[1]])[1]; + + // Account for multiline widgets + if (node.widgets) { + for (const widget of node.widgets) { + if (widget.type === "textarea" || widget.options?.multiline) { + // Adjust height based on content + const lines = (widget.value || "").split("\n").length; + const lineHeight = 20; // Adjust this value based on your CSS + const widgetHeight = Math.max(lines * lineHeight, widget.options?.minHeight || 60); + computedHeight += widgetHeight - (widget.options?.minHeight || 60); // Add extra height + } + } + } + + // Set minimum height + computedHeight = Math.max(computedHeight, node.options?.minHeight || 100); + + if (computedHeight !== node.size[1]) { + node.setSize([node.size[0], computedHeight]); + node.graph?.setDirtyCanvas(true, true); + } + + return [node.size[0], computedHeight]; + } catch (error) { + console.error("Error in fitHeight:", error); + return null; + } } export function node_add_dynamic(nodeType, prefix, type='*', count=-1) { @@ -67,3 +103,54 @@ export function node_add_dynamic(nodeType, prefix, type='*', count=-1) { } return nodeType; } + +// TAKEN FROM: ComfyUI\web\extensions\core\widgetInputs.js +// IN CASE someone tries to tell you they invented core functions... +// they simply are not exported +// +const CONVERTED_TYPE = "converted-widget"; + +export function hideWidget(node, widget, suffix = "") { + if (widget.type?.startsWith(CONVERTED_TYPE)) return; + widget.origType = widget.type; + widget.origComputeSize = widget.computeSize; + widget.origSerializeValue = widget.serializeValue; + widget.computeSize = () => [0, -4]; // -4 is due to the gap litegraph adds between widgets automatically + widget.type = CONVERTED_TYPE + suffix; + widget.serializeValue = () => { + // Prevent serializing the widget if we have no input linked + if (!node.inputs) { + return undefined; + } + let node_input = node.inputs.find((i) => i.widget?.name === widget.name); + + if (!node_input || !node_input.link) { + return undefined; + } + return widget.origSerializeValue ? widget.origSerializeValue() : widget.value; + }; + + // Hide any linked widgets, e.g. seed+seedControl + if (widget.linkedWidgets) { + for (const w of widget.linkedWidgets) { + hideWidget(node, w, ":" + widget.name); + } + } +} + +export function showWidget(widget) { + widget.type = widget.origType; + widget.computeSize = widget.origComputeSize; + widget.serializeValue = widget.origSerializeValue; + + delete widget.origType; + delete widget.origComputeSize; + delete widget.origSerializeValue; + + // Hide any linked widgets, e.g. seed+seedControl + if (widget.linkedWidgets) { + for (const w of widget.linkedWidgets) { + showWidget(w); + } + } +} \ No newline at end of file diff --git a/js/versions.js b/js/versions.js index 5073dda..7639009 100644 --- a/js/versions.js +++ b/js/versions.js @@ -1,11 +1,20 @@ export const versions = { - "version": "0.30.2", - "releaseDate": "2024-08-29", + "version": "0.30.2a", + "releaseDate": "2024-08-30", "name": "ComfyUI-Griptape", "description": "Griptape integration for ComfyUI", "author": "Jason Schleifer", "repository": "https://github.com/griptape-ai/ComfyUI-Griptape", "changelog": [ + { + "version": "0.30.2a", + "date": "2024-08-30", + "changes": [ + "Added Griptape Tool: Extraction node", + "Added Griptape Tool: Prompt Summary", + "Addded max_tokens parameter to most prompt models", + ] + }, { "version": "0.30.2", "date": "2024-08-29", diff --git a/nodes/config/gtUIAmazonBedrockStructureConfig.py b/nodes/config/gtUIAmazonBedrockStructureConfig.py index 4018b49..307bf4f 100644 --- a/nodes/config/gtUIAmazonBedrockStructureConfig.py +++ b/nodes/config/gtUIAmazonBedrockStructureConfig.py @@ -65,16 +65,20 @@ def create( self, **kwargs, ): + params = {} + prompt_model = kwargs.get("prompt_model", amazonBedrockPromptModels[0]) temperature = kwargs.get("temperature", 0.7) max_attempts = kwargs.get("max_attempts_on_fail", 10) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) + params["model"] = prompt_model + params["temperature"] = temperature + params["max_attempts"] = max_attempts + params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens custom_config = AmazonBedrockDriversConfig( - prompt_driver=AmazonBedrockPromptDriver( - model=prompt_model, - temperature=temperature, - max_attempts=max_attempts, - use_native_tools=use_native_tools, - ) + prompt_driver=AmazonBedrockPromptDriver(**params) ) return (custom_config,) diff --git a/nodes/config/gtUIAnthropicStructureConfig.py b/nodes/config/gtUIAnthropicStructureConfig.py index 5e24ba3..910364b 100644 --- a/nodes/config/gtUIAnthropicStructureConfig.py +++ b/nodes/config/gtUIAnthropicStructureConfig.py @@ -50,20 +50,20 @@ def create( self, **kwargs, ): - prompt_model = kwargs.get("prompt_model", anthropicPromptModels[0]) - temperature = kwargs.get("temperature", 0.7) - max_attempts = kwargs.get("max_attempts_on_fail", 10) - use_native_tools = kwargs.get("use_native_tools", True) - api_key = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY)) + params = {} + + params["model"] = kwargs.get("prompt_model", anthropicPromptModels[0]) + params["temperature"] = kwargs.get("temperature", 0.7) + params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10) + params["use_native_tools"] = kwargs.get("use_native_tools", True) + params["api_key"] = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY)) + + max_tokens = kwargs.get("max_tokens", None) + if max_tokens > 0: + params["max_tokens"] = max_tokens custom_config = AnthropicDriversConfig( - prompt_driver=AnthropicPromptDriver( - model=prompt_model, - temperature=temperature, - max_attempts=max_attempts, - api_key=api_key, - use_native_tools=use_native_tools, - ) + prompt_driver=AnthropicPromptDriver(**params) ) return (custom_config,) diff --git a/nodes/config/gtUIAzureOpenAiStructureConfig.py b/nodes/config/gtUIAzureOpenAiStructureConfig.py index a32ed3b..6c506c8 100644 --- a/nodes/config/gtUIAzureOpenAiStructureConfig.py +++ b/nodes/config/gtUIAzureOpenAiStructureConfig.py @@ -54,14 +54,20 @@ def create( self, **kwargs, ): - prompt_model = kwargs.get("prompt_model", "gpt-4o") - temperature = kwargs.get("temperature", 0.7) - seed = kwargs.get("seed", 12341) + params = {} + params["model"] = kwargs.get("prompt_model", "gpt-4o") + params["temperature"] = kwargs.get("temperature", 0.7) + params["seed"] = kwargs.get("seed", 12341) + params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10) + params["azure_deployment"] = kwargs.get("prompt_model_deployment_name", "gpt4o") + params["use_native_tools"] = kwargs.get("use_native_tools", False) + params["stream"] = kwargs.get("stream", False) image_generation_driver = kwargs.get("image_generation_driver", None) - max_attempts = kwargs.get("max_attempts_on_fail", 10) - prompt_model_deployment_id = kwargs.get("prompt_model_deployment_name", "gpt4o") - use_native_tools = kwargs.get("use_native_tools", False) - stream = kwargs.get("stream", False) + + max_tokens = kwargs.get("max_tokens", -1) + if max_tokens > 0: + params["max_tokens"] = max_tokens + AZURE_OPENAI_API_KEY = self.getenv( kwargs.get("api_key_env_var", DEFAULT_AZURE_OPENAI_API_KEY) ) @@ -69,17 +75,10 @@ def create( kwargs.get("azure_endpoint_env_var", DEFAULT_AZURE_OPENAI_ENDPOINT) ) - prompt_driver = AzureOpenAiChatPromptDriver( - api_key=AZURE_OPENAI_API_KEY, - model=prompt_model, - azure_endpoint=AZURE_OPENAI_ENDPOINT, - azure_deployment=prompt_model_deployment_id, - temperature=temperature, - seed=seed, - max_attempts=max_attempts, - stream=stream, - use_native_tools=use_native_tools, - ) + params["api_key"] = self.getenv("AZURE_OPENAI_API_KEY") + params["azure_endpoint"] = self.getenv("AZURE_OPENAI_ENDPOINT") + + prompt_driver = AzureOpenAiChatPromptDriver(**params) embedding_driver = AzureOpenAiEmbeddingDriver( api_key=self.getenv(AZURE_OPENAI_API_KEY), azure_endpoint=self.getenv(AZURE_OPENAI_ENDPOINT), @@ -89,16 +88,16 @@ def create( image_generation_driver = AzureOpenAiImageGenerationDriver( azure_deployment="dall-e-3", model="dall-e-3", - azure_endpoint=AZURE_OPENAI_ENDPOINT, - api_key=AZURE_OPENAI_API_KEY, + azure_endpoint=params["azure_endpoint"], + api_key=params["api_key"], ) custom_config = AzureOpenAiDriversConfig( prompt_driver=prompt_driver, embedding_driver=embedding_driver, image_generation_driver=image_generation_driver, - azure_endpoint=AZURE_OPENAI_ENDPOINT, - api_key=AZURE_OPENAI_API_KEY, + azure_endpoint=params["azure_endpoint"], + api_key=params["api_key"], ) return (custom_config,) diff --git a/nodes/config/gtUIBaseConfig.py b/nodes/config/gtUIBaseConfig.py index 637bbfe..3b4ecb8 100644 --- a/nodes/config/gtUIBaseConfig.py +++ b/nodes/config/gtUIBaseConfig.py @@ -29,6 +29,13 @@ def INPUT_TYPES(s): # "stream": ([True, False], {"default": False}), "env": ("ENV", {"default": None}), "use_native_tools": ("BOOLEAN", {"default": True}), + "max_tokens": ( + "INT", + { + "default": -1, + "tooltip": "Maximum tokens to generate. If <=0, it will use the default based on the tokenizer.", + }, + ), }, } diff --git a/nodes/config/gtUIGoogleStructureConfig.py b/nodes/config/gtUIGoogleStructureConfig.py index 99ec663..a9b2569 100644 --- a/nodes/config/gtUIGoogleStructureConfig.py +++ b/nodes/config/gtUIGoogleStructureConfig.py @@ -51,11 +51,15 @@ def create( self, **kwargs, ): + params = {} temperature = kwargs.get("temperature", 0.7) prompt_model = kwargs.get("prompt_model", google_models[0]) max_attempts = kwargs.get("max_attempts_on_fail", 10) api_key = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY)) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", -1) + if max_tokens > 0: + params["max_tokens"] = max_tokens custom_config = GoogleDriversConfig( prompt_driver=GooglePromptDriver( @@ -64,6 +68,7 @@ def create( max_attempts=max_attempts, api_key=api_key, use_native_tools=use_native_tools, + **params, ), ) diff --git a/nodes/config/gtUIHuggingFaceStructureConfig.py b/nodes/config/gtUIHuggingFaceStructureConfig.py index e8db2c7..631e634 100644 --- a/nodes/config/gtUIHuggingFaceStructureConfig.py +++ b/nodes/config/gtUIHuggingFaceStructureConfig.py @@ -34,6 +34,8 @@ def INPUT_TYPES(s): ), } ) + del inputs["optional"]["max_tokens"] + return inputs def create(self, **kwargs): @@ -43,7 +45,6 @@ def create(self, **kwargs): max_attempts = kwargs.get("max_attempts_on_fail", 10) use_native_tools = kwargs.get("use_native_tools", False) api_token = self.getenv(kwargs.get("api_token_env_var", DEFAULT_API_KEY)) - configs = {} if prompt_model and api_token: configs["prompt_driver"] = HuggingFaceHubPromptDriver( diff --git a/nodes/config/gtUILMStudioStructureConfig.py b/nodes/config/gtUILMStudioStructureConfig.py index 15945bb..3d9c695 100644 --- a/nodes/config/gtUILMStudioStructureConfig.py +++ b/nodes/config/gtUILMStudioStructureConfig.py @@ -43,25 +43,23 @@ def INPUT_TYPES(s): return inputs def create(self, **kwargs): - model = kwargs.get("model", "") - base_url = kwargs.get("base_url", lmstudio_base_url) + params = {} + params["model"] = kwargs.get("model", "") port = kwargs.get("port", lmstudio_port) - temperature = kwargs.get("temperature", 0.7) - max_attempts = kwargs.get("max_attempts_on_fail", 10) - stream = kwargs.get("stream", False) - seed = kwargs.get("seed", 12341) - use_native_tools = kwargs.get("use_native_tools", False) + base_url = kwargs.get("base_url", lmstudio_base_url) + params["base_url"] = f"{base_url}:{port}/v1" + params["temperature"] = kwargs.get("temperature", 0.7) + params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10) + params["stream"] = kwargs.get("stream", False) + params["seed"] = kwargs.get("seed", 12341) + params["use_native_tools"] = kwargs.get("use_native_tools", False) + + max_tokens = kwargs.get("max_tokens", -1) + if max_tokens > 0: + params["max_tokens"] = max_tokens + custom_config = DriversConfig( - prompt_driver=OpenAiChatPromptDriver( - model=model, - base_url=f"{base_url}:{port}/v1", - api_key="lm_studio", - temperature=temperature, - max_attempts=max_attempts, - stream=stream, - seed=seed, - use_native_tools=use_native_tools, - ), + prompt_driver=OpenAiChatPromptDriver(**params), ) return (custom_config,) diff --git a/nodes/config/gtUIOllamaStructureConfig.py b/nodes/config/gtUIOllamaStructureConfig.py index a18bc0a..08ed4b2 100644 --- a/nodes/config/gtUIOllamaStructureConfig.py +++ b/nodes/config/gtUIOllamaStructureConfig.py @@ -33,24 +33,22 @@ def INPUT_TYPES(s): return inputs def create(self, **kwargs): - prompt_model = kwargs.get("prompt_model", "") - temperature = kwargs.get("temperature", 0.7) - base_url = kwargs.get("base_url", ollama_base_url) - port = kwargs.get("port", ollama_port) - stream = kwargs.get("stream", False) - use_native_tools = kwargs.get("use_native_tools", False) + params = {} - max_attempts = kwargs.get("max_attempts_on_fail", 10) + params["model"] = kwargs.get("prompt_model", "") + params["temperature"] = kwargs.get("temperature", 0.7) + port = kwargs.get("port", ollama_port) + base_url = kwargs.get("base_url", ollama_base_url) + params["host"] = f"{base_url}:{port}" + params["stream"] = kwargs.get("stream", False) + params["use_native_tools"] = kwargs.get("use_native_tools", False) + params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10) + max_tokens = kwargs.get("max_tokens", -1) + if max_tokens > 0: + params["max_tokens"] = max_tokens custom_config = DriversConfig( - prompt_driver=OllamaPromptDriver( - model=prompt_model, - temperature=temperature, - host=f"{base_url}:{port}", - max_attempts=max_attempts, - stream=stream, - use_native_tools=use_native_tools, - ), + prompt_driver=OllamaPromptDriver(**params), ) return (custom_config,) diff --git a/nodes/config/gtUIOpenAiCompatibleConfig.py b/nodes/config/gtUIOpenAiCompatibleConfig.py index 906cb58..7e43181 100644 --- a/nodes/config/gtUIOpenAiCompatibleConfig.py +++ b/nodes/config/gtUIOpenAiCompatibleConfig.py @@ -42,39 +42,38 @@ def INPUT_TYPES(s): return inputs def create(self, **kwargs): - prompt_model = kwargs.get("prompt_model", None) + params = {} + + params["model"] = kwargs.get("prompt_model", None) + params["base_url"] = kwargs.get("base_url", None) + params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10) + params["stream"] = kwargs.get("stream", False) + params["use_native_tools"] = kwargs.get("use_native_tools", False) + api_key_env_var = kwargs.get("api_key_env_var", DEFAULT_API_KEY) + params["api_key"] = self.getenv(api_key_env_var) image_generation_model = kwargs.get("image_generation_model", None) text_to_speech_model = kwargs.get("text_to_speech_model", None) - base_url = kwargs.get("base_url", None) - max_attempts = kwargs.get("max_attempts_on_fail", 10) - stream = kwargs.get("stream", False) - use_native_tools = kwargs.get("use_native_tools", False) - api_key_env_var = kwargs.get("api_key_env_var", DEFAULT_API_KEY) - api_key = self.getenv(api_key_env_var) - if not api_key: - api_key = api_key_env_var + max_tokens = kwargs.get("max_tokens", -1) + if max_tokens > 0: + params["max_tokens"] = max_tokens + + if not params["api_key"]: + params["api_key"] = api_key_env_var configs = {} - if prompt_model and base_url and api_key: - configs["prompt_driver"] = OpenAiChatPromptDriver( - model=prompt_model, - base_url=base_url, - api_key=api_key, - max_attempts=max_attempts, - stream=stream, - use_native_tools=use_native_tools, - ) - if image_generation_model and base_url and api_key: + if params["model"] and params["base_url"] and params["api_key"]: + configs["prompt_driver"] = OpenAiChatPromptDriver(**params) + if image_generation_model and params["base_url"] and params["api_key"]: configs["image_generation_driver"] = OpenAiImageGenerationDriver( model=image_generation_model, - base_url=base_url, - api_key=api_key, + base_url=params["base_url"], + api_key=params["api_key"], ) - if text_to_speech_model and base_url and api_key: + if text_to_speech_model and params["base_url"] and params["api_key"]: configs["text_to_speech_driver"] = OpenAiTextToSpeechDriver( model=text_to_speech_model, - base_url=base_url, - api_key=api_key, + base_url=params["base_url"], + api_key=params["api_key"], ) custom_config = DriversConfig(**configs) return (custom_config,) diff --git a/nodes/config/gtUIOpenAiStructureConfig.py b/nodes/config/gtUIOpenAiStructureConfig.py index c2da006..9e80c8e 100644 --- a/nodes/config/gtUIOpenAiStructureConfig.py +++ b/nodes/config/gtUIOpenAiStructureConfig.py @@ -44,22 +44,20 @@ def create( self, **kwargs, ): - prompt_model = kwargs.get("prompt_model", default_prompt_model) - temperature = kwargs.get("temperature", 0.7) - seed = kwargs.get("seed", 12341) - max_attempts = kwargs.get("max_attempts_on_fail", 10) - api_key = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY)) - use_native_tools = kwargs.get("use_native_tools", False) + params = {} + params["model"] = kwargs.get("prompt_model", default_prompt_model) + params["temperature"] = kwargs.get("temperature", 0.7) + params["seed"] = kwargs.get("seed", 12341) + params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10) + params["api_key"] = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY)) + params["use_native_tools"] = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", -1) + if max_tokens > 0: + params["max_tokens"] = max_tokens + try: Defaults.drivers_config = OpenAiDriversConfig( - prompt_driver=OpenAiChatPromptDriver( - model=prompt_model, - api_key=api_key, - temperature=temperature, - seed=seed, - max_attempts=max_attempts, - use_native_tools=use_native_tools, - ) + prompt_driver=OpenAiChatPromptDriver(**params) ) # OpenAiStructureConfig() diff --git a/nodes/drivers/gtUIAmazonBedrockPromptDriver.py b/nodes/drivers/gtUIAmazonBedrockPromptDriver.py index 390326d..f260987 100644 --- a/nodes/drivers/gtUIAmazonBedrockPromptDriver.py +++ b/nodes/drivers/gtUIAmazonBedrockPromptDriver.py @@ -59,6 +59,8 @@ def create(self, **kwargs): "secret_key_env_var", DEFAULT_AWS_SECRET_ACCESS_KEY ) api_key_env_var = kwargs.get("api_key_env_var", DEFAULT_AWS_ACCESS_KEY_ID) + max_tokens = kwargs.get("max_tokens", 0) + params = {} # Create a boto3 session @@ -76,6 +78,8 @@ def create(self, **kwargs): params["temperature"] = temperature if max_attempts: params["max_attempts"] = max_attempts + if max_tokens > 0: + params["max_tokens"] = max_tokens # if session: # params["session"] = session if use_native_tools: diff --git a/nodes/drivers/gtUIAnthropicPromptDriver.py b/nodes/drivers/gtUIAnthropicPromptDriver.py index cf4e0bc..dd7bd1b 100644 --- a/nodes/drivers/gtUIAnthropicPromptDriver.py +++ b/nodes/drivers/gtUIAnthropicPromptDriver.py @@ -39,6 +39,7 @@ def create(self, **kwargs): max_attempts = kwargs.get("max_attempts_on_fail", None) api_key = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY)) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) params = {} if api_key: @@ -53,6 +54,8 @@ def create(self, **kwargs): params["max_attempts"] = max_attempts if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = AnthropicPromptDriver(**params) return (driver,) diff --git a/nodes/drivers/gtUIAzureOpenAiChatPromptDriver.py b/nodes/drivers/gtUIAzureOpenAiChatPromptDriver.py index 14ff02e..09b0fee 100644 --- a/nodes/drivers/gtUIAzureOpenAiChatPromptDriver.py +++ b/nodes/drivers/gtUIAzureOpenAiChatPromptDriver.py @@ -43,6 +43,7 @@ def create(self, **kwargs): max_attempts_on_fail = kwargs.get("max_attempts_on_fail", None) api_key_env_var = kwargs.get("api_key_env_var", DEFAULT_API_KEY_ENV_VAR) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) azure_endpoint_env_var = kwargs.get( "endpoint_env_var", DEFAULT_AZURE_ENDPOINT_ENV_VAR ) @@ -74,6 +75,8 @@ def create(self, **kwargs): params["max_attempts"] = max_attempts_on_fail if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = AzureOpenAiChatPromptDriver(**params) return (driver,) diff --git a/nodes/drivers/gtUIBasePromptDriver.py b/nodes/drivers/gtUIBasePromptDriver.py index e1b7b9d..78a8026 100644 --- a/nodes/drivers/gtUIBasePromptDriver.py +++ b/nodes/drivers/gtUIBasePromptDriver.py @@ -21,6 +21,13 @@ def INPUT_TYPES(s): ), "seed": ("INT", {"default": 10342349342}), "use_native_tools": ("BOOLEAN", {"default": True}), + "max_tokens": ( + "INT", + { + "default": -1, + "tooltip": "Maximum tokens to generate. If <=0, it will use the default based on the tokenizer.", + }, + ), }, ) return inputs diff --git a/nodes/drivers/gtUICoherePromptDriver.py b/nodes/drivers/gtUICoherePromptDriver.py index ad3e6f5..306d82a 100644 --- a/nodes/drivers/gtUICoherePromptDriver.py +++ b/nodes/drivers/gtUICoherePromptDriver.py @@ -44,6 +44,7 @@ def create(self, **kwargs): stream = kwargs.get("stream", False) max_attempts = kwargs.get("max_attempts_on_fail", None) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) params = {} if api_key: @@ -56,6 +57,8 @@ def create(self, **kwargs): params["max_attempts"] = max_attempts if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = CoherePromptDriver(**params) return (driver,) diff --git a/nodes/drivers/gtUIGooglePromptDriver.py b/nodes/drivers/gtUIGooglePromptDriver.py index 39f2dc2..5dcf7d5 100644 --- a/nodes/drivers/gtUIGooglePromptDriver.py +++ b/nodes/drivers/gtUIGooglePromptDriver.py @@ -37,6 +37,7 @@ def create(self, **kwargs): temperature = kwargs.get("temperature", None) max_attempts = kwargs.get("max_attempts_on_fail", None) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) params = {} if api_key: @@ -51,6 +52,8 @@ def create(self, **kwargs): params["max_attempts"] = max_attempts if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = GooglePromptDriver(**params) return (driver,) diff --git a/nodes/drivers/gtUILMStudioChatPromptDriver.py b/nodes/drivers/gtUILMStudioChatPromptDriver.py index b0663b6..43e015d 100644 --- a/nodes/drivers/gtUILMStudioChatPromptDriver.py +++ b/nodes/drivers/gtUILMStudioChatPromptDriver.py @@ -34,6 +34,7 @@ def create(self, **kwargs): temperature = kwargs.get("temperature", None) max_attempts = kwargs.get("max_attempts_on_fail", None) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) params = {} @@ -49,6 +50,8 @@ def create(self, **kwargs): params["base_url"] = f"{base_url}:{port}/v1" if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = OpenAiChatPromptDriver(**params) diff --git a/nodes/drivers/gtUIOllamaPromptDriver.py b/nodes/drivers/gtUIOllamaPromptDriver.py index c710de7..b7803ca 100644 --- a/nodes/drivers/gtUIOllamaPromptDriver.py +++ b/nodes/drivers/gtUIOllamaPromptDriver.py @@ -32,6 +32,8 @@ def create(self, **kwargs): temperature = kwargs.get("temperature", None) max_attempts = kwargs.get("max_attempts_on_fail", None) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) + params = {} if model: @@ -46,6 +48,8 @@ def create(self, **kwargs): params["host"] = f"{base_url}:{port}" if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = OllamaPromptDriver(**params) return (driver,) diff --git a/nodes/drivers/gtUIOpenAiChatPromptDriver.py b/nodes/drivers/gtUIOpenAiChatPromptDriver.py index e0e49e0..abe873a 100644 --- a/nodes/drivers/gtUIOpenAiChatPromptDriver.py +++ b/nodes/drivers/gtUIOpenAiChatPromptDriver.py @@ -36,6 +36,7 @@ def create(self, **kwargs): temperature = kwargs.get("temperature", None) max_attempts = kwargs.get("max_attempts_on_fail", None) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) params = {} if api_key: @@ -54,6 +55,8 @@ def create(self, **kwargs): params["max_attempts"] = max_attempts if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = OpenAiChatPromptDriver(**params) return (driver,) diff --git a/nodes/drivers/gtUIOpenAiCompatibleChatPromptDriver.py b/nodes/drivers/gtUIOpenAiCompatibleChatPromptDriver.py index 7f3be39..cf6cb18 100644 --- a/nodes/drivers/gtUIOpenAiCompatibleChatPromptDriver.py +++ b/nodes/drivers/gtUIOpenAiCompatibleChatPromptDriver.py @@ -37,6 +37,7 @@ def create(self, **kwargs): temperature = kwargs.get("temperature", None) max_attempts = kwargs.get("max_attempts_on_fail", None) use_native_tools = kwargs.get("use_native_tools", False) + max_tokens = kwargs.get("max_tokens", None) params = {} @@ -54,6 +55,8 @@ def create(self, **kwargs): params["api_key"] = self.getenv(api_key_env_var) if use_native_tools: params["use_native_tools"] = use_native_tools + if max_tokens > 0: + params["max_tokens"] = max_tokens try: driver = OpenAiChatPromptDriver(**params) diff --git a/nodes/tools/gtUIExtractionTool.py b/nodes/tools/gtUIExtractionTool.py new file mode 100644 index 0000000..dd1d799 --- /dev/null +++ b/nodes/tools/gtUIExtractionTool.py @@ -0,0 +1,69 @@ +from griptape.drivers import OpenAiChatPromptDriver +from griptape.engines import CsvExtractionEngine, JsonExtractionEngine +from griptape.rules import Rule +from griptape.tools import ExtractionTool + +from .gtUIBaseTool import gtUIBaseTool + +extraction_engines = ["csv", "json"] + + +class gtUIExtractionTool(gtUIBaseTool): + """ + The Griptape Prompt Summary Tool + """ + + @classmethod + def INPUT_TYPES(s): + inputs = super().INPUT_TYPES() + + inputs["optional"].update( + { + "prompt_driver": ("PROMPT_DRIVER", {}), + "extraction_type": (extraction_engines, {"default": "json"}), + "column_names": ( + "STRING", + { + "default": "", + "tooltip": "Comma separated list of column names to extract. Example:\n name, age", + }, + ), + "template_schema": ( + "STRING", + { + "default": "", + "multiline": True, + "tooltip": 'Schema template for the json extraction. Example:\n {"name": str, "age": int}', + }, + ), + } + ) + del inputs["required"]["off_prompt"] + + return inputs + + DESCRIPTION = ( + "Prompt Summary Tool - Summarizes information that is found in Task Memory." + ) + + def create(self, **kwargs): + prompt_driver = kwargs.get("prompt_driver", None) + extraction_type = kwargs.get("extraction_type", "json") + column_names = kwargs.get("column_names", "") + template_schema = kwargs.get("template_schema", "") + params = {} + + if not prompt_driver: + prompt_driver = OpenAiChatPromptDriver() + if extraction_type == "csv": + engine = CsvExtractionEngine( + prompt_driver=prompt_driver, column_names=column_names + ) + elif extraction_type == "json": + engine = JsonExtractionEngine( + prompt_driver=prompt_driver, template_schema=template_schema + ) + + params["extraction_engine"] = engine + tool = ExtractionTool(**params, rules=[Rule("Raw output please")]) + return ([tool],) diff --git a/nodes/tools/gtUIPromptSummaryTool.py b/nodes/tools/gtUIPromptSummaryTool.py new file mode 100644 index 0000000..379071a --- /dev/null +++ b/nodes/tools/gtUIPromptSummaryTool.py @@ -0,0 +1,36 @@ +from griptape.engines import PromptSummaryEngine +from griptape.tools import PromptSummaryTool + +from .gtUIBaseTool import gtUIBaseTool + + +class gtUIPromptSummaryTool(gtUIBaseTool): + """ + The Griptape Prompt Summary Tool + """ + + @classmethod + def INPUT_TYPES(s): + inputs = super().INPUT_TYPES() + + inputs["optional"].update( + { + "prompt_driver": ("PROMPT_DRIVER", {}), + } + ) + del inputs["required"]["off_prompt"] + + return inputs + + DESCRIPTION = ( + "Prompt Summary Tool - Summarizes information that is found in Task Memory." + ) + + def create(self, **kwargs): + prompt_driver = kwargs.get("prompt_driver", None) + + params = {} + engine = PromptSummaryEngine(prompt_driver=prompt_driver) + params["prompt_summary_engine"] = engine + tool = PromptSummaryTool(**params) + return ([tool],)