Skip to content

Commit

Permalink
Merge pull request #113 from griptape-ai:dev
Browse files Browse the repository at this point in the history
* Added `max_tokens` to most configuration and prompt_driver nodes. This gives you the ability to control how many tokens come back from the LLM. _Note: It's a known issue that AmazonBedrock doesn't work with max_tokens at the moment._
* Added `Griptape Tool: Extraction` node that lets you extract either json or csv text with either a json schema or column header definitions. This works well with TaskMemory.
* Added `Griptape Tool: Prompt Summary` node that will summarize text. This works well with TaskMemory.
  • Loading branch information
shhlife authored Aug 29, 2024
2 parents 868b6f1 + e14aaab commit 898a7fa
Show file tree
Hide file tree
Showing 30 changed files with 423 additions and 117 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,5 @@ cython_debug/
poetry.lock
pyproject.toml
griptape_config.json
pyproject_old.toml
pyproject_pub.toml
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ You can preview and download more examples [here](examples/README.md).

## Recent Changelog

### Aug 30, 2024
* Added `max_tokens` to most configuration and prompt_driver nodes. This gives you the ability to control how many tokens come back from the LLM. _Note: It's a known issue that AmazonBedrock doesn't work with max_tokens at the moment._
* Added `Griptape Tool: Extraction` node that lets you extract either json or csv text with either a json schema or column header definitions. This works well with TaskMemory.
* Added `Griptape Tool: Prompt Summary` node that will summarize text. This works well with TaskMemory.

### Aug 29, 2024
* Updated griptape version to 0.30.2 - This is a major change to how Griptape handles configurations, but I tried to ensure all nodes and workflows still work. Please let us know if there are any issues.
* Added `Griptape Tool: Query` node to allow Task Memory to go "Off Prompt"
Expand Down
6 changes: 5 additions & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@
from .nodes.tools.gtUICalculator import gtUICalculator
from .nodes.tools.gtUIConvertAgentToTool import gtUIConvertAgentToTool
from .nodes.tools.gtUIDateTime import gtUIDateTime
from .nodes.tools.gtUIExtractionTool import gtUIExtractionTool
from .nodes.tools.gtUIFileManager import gtUIFileManager
from .nodes.tools.gtUIKnowledgeBaseTool import gtUIKnowledgeBaseTool
from .nodes.tools.gtUIPromptSummaryTool import gtUIPromptSummaryTool
from .nodes.tools.gtUIQueryTool import gtUIQueryTool
from .nodes.tools.gtUITextToSpeechClient import gtUITextToSpeechClient
from .nodes.tools.gtUIVectorStoreClient import gtUIVectorStoreClient
Expand Down Expand Up @@ -346,10 +348,12 @@
"Griptape Tool: FileManager": gtUIFileManager,
"Griptape Tool: Griptape Cloud KnowledgeBase": gtUIKnowledgeBaseTool,
"Griptape Tool: Text to Speech": gtUITextToSpeechClient,
"Griptape Tool: Query": gtUIQueryTool,
"Griptape Tool: VectorStore": gtUIVectorStoreClient,
"Griptape Tool: WebScraper": gtUIWebScraper,
"Griptape Tool: WebSearch": gtUIWebSearch,
"Griptape Tool: Extraction": gtUIExtractionTool,
"Griptape Tool: Prompt Summary": gtUIPromptSummaryTool,
"Griptape Tool: Query": gtUIQueryTool,
# DISPLAY
"Griptape Display: Image": gtUIOutputImageNode,
"Griptape Display: Text": gtUIOutputStringNode,
Expand Down
46 changes: 46 additions & 0 deletions js/ExtractionNodes.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { ComfyWidgets } from "../../../scripts/widgets.js";
import { fitHeight } from "./utils.js";
import { formatAndDisplayJSON } from "./gtUIUtils.js";
import { hideWidget, showWidget } from "./utils.js";
import { app } from "../../../scripts/app.js";
// Wire up Extraction-specific UI behavior for matching node types.
export function setupExtractionNodes(nodeType, nodeData, app) {
    const isExtractionTool = nodeData.name === "Griptape Tool: Extraction";
    if (isExtractionTool) {
        setupExtractionTypeAttrr(nodeType, nodeData, app);
    }
}

// Patch onNodeCreated so the extraction node only shows the widget that
// matches the selected extraction_type: "csv" reveals column_names,
// "json" reveals template_schema, anything else hides both.
function setupExtractionTypeAttrr(nodeType, nodeData, app) {
    const previousOnNodeCreated = nodeType.prototype.onNodeCreated
    nodeType.prototype.onNodeCreated = function() {
        const result = previousOnNodeCreated?.apply(this);

        const findWidget = (name) => this.widgets.find((w) => w.name === name);
        const typeWidget = findWidget('extraction_type');
        const columnsWidget = findWidget('column_names');
        const schemaWidget = findWidget('template_schema');

        typeWidget.callback = async () => {
            // Start from a clean slate: hide both optional widgets...
            hideWidget(this, columnsWidget);
            hideWidget(this, schemaWidget);

            // ...then reveal only the one relevant to the current selection.
            if (typeWidget.value === "csv") {
                showWidget(columnsWidget);
            } else if (typeWidget.value === "json") {
                showWidget(schemaWidget);
            }
        }

        // Apply the initial visibility shortly after creation, once the
        // widget values have settled.
        setTimeout(() => { typeWidget.callback() }, 5);
        return result;
    };

}

1 change: 1 addition & 0 deletions js/gtUIMenuSeparator.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const sep_above_items = [
"Griptape Agent Config: Amazon Bedrock",
// Sub Menu Item - Agent Tools
"Griptape Tool: Audio Transcription",
"Griptape Tool: Extraction",
// Sub Menu Item - Audio
"Griptape Load: Audio",
// Sub Menu Items - Image
Expand Down
2 changes: 2 additions & 0 deletions js/gtUINodes.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { setupConfigurationNodes } from "./ConfigurationNodes.js";
import { setupNodeColors } from "./ColorNodes.js";
import { setupDisplayNodes } from "./DisplayNodes.js";
import { setupCombineNodes } from "./CombineNodes.js";
import { setupExtractionNodes } from "./ExtractionNodes.js";
import { gtUIAddUploadWidget } from "./gtUIUtils.js";
import { setupMenuSeparator } from "./gtUIMenuSeparator.js";
// app.extensionManager.registerSidebarTab({
Expand Down Expand Up @@ -57,6 +58,7 @@ app.registerExtension({
setupConfigurationNodes(nodeType, nodeData, app);
setupDisplayNodes(nodeType, nodeData, app);
setupCombineNodes(nodeType, nodeData, app);
setupExtractionNodes(nodeType, nodeData, app);


// Create Audio Node
Expand Down
93 changes: 90 additions & 3 deletions js/utils.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,43 @@
// export function fitHeight(node) {
// node.onResize?.(node.size);
// node.setSize([node.size[0], node.computeSize([node.size[0], node.size[1]])[1]])
// node?.graph?.setDirtyCanvas(true, true);
// }
// Resize `node` so its height fits its widgets, growing for multiline
// text widgets whose content exceeds their minimum height.
// Returns the new [width, height], or null when `node` is missing or
// an error occurs while resizing.
export function fitHeight(node) {
    // Guard must run before any property access on `node`.
    if (!node) return null;

    try {
        node.onResize?.(node.size);

        // Base height from litegraph's own size computation.
        let computedHeight = node.computeSize([node.size[0], node.size[1]])[1];

        // Account for multiline widgets: add the height their content
        // needs beyond the widget's minimum height.
        if (node.widgets) {
            for (const widget of node.widgets) {
                if (widget.type === "textarea" || widget.options?.multiline) {
                    const lines = (widget.value || "").split("\n").length;
                    const lineHeight = 20; // NOTE(review): should match the widget CSS line height — confirm
                    const minHeight = widget.options?.minHeight || 60;
                    const widgetHeight = Math.max(lines * lineHeight, minHeight);
                    computedHeight += widgetHeight - minHeight; // extra height beyond the minimum
                }
            }
        }

        // Never shrink below the node's own minimum height.
        computedHeight = Math.max(computedHeight, node.options?.minHeight || 100);

        if (computedHeight !== node.size[1]) {
            node.setSize([node.size[0], computedHeight]);
            node.graph?.setDirtyCanvas(true, true);
        }

        return [node.size[0], computedHeight];
    } catch (error) {
        console.error("Error in fitHeight:", error);
        return null;
    }
}

export function node_add_dynamic(nodeType, prefix, type='*', count=-1) {
Expand Down Expand Up @@ -67,3 +103,54 @@ export function node_add_dynamic(nodeType, prefix, type='*', count=-1) {
}
return nodeType;
}

// TAKEN FROM: ComfyUI\web\extensions\core\widgetInputs.js
// IN CASE someone tries to tell you they invented core functions...
// they simply are not exported
//
const CONVERTED_TYPE = "converted-widget";

// Collapse `widget` (and its linked widgets) so it no longer renders or
// serializes, stashing the original type/size/serialize hooks so that
// showWidget() can restore them later. No-op when already hidden.
export function hideWidget(node, widget, suffix = "") {
    // Already converted — don't clobber the stashed originals.
    if (widget.type?.startsWith(CONVERTED_TYPE)) return;

    widget.origType = widget.type;
    widget.origComputeSize = widget.computeSize;
    widget.origSerializeValue = widget.serializeValue;

    // -4 compensates for the gap litegraph adds between widgets.
    widget.computeSize = () => [0, -4];
    widget.type = CONVERTED_TYPE + suffix;

    widget.serializeValue = () => {
        // Only serialize when an input is actually linked to this widget.
        if (!node.inputs) {
            return undefined;
        }
        const linkedInput = node.inputs.find((i) => i.widget?.name === widget.name);
        if (!linkedInput || !linkedInput.link) {
            return undefined;
        }
        return widget.origSerializeValue ? widget.origSerializeValue() : widget.value;
    };

    // Also hide any linked widgets, e.g. seed + seedControl.
    if (widget.linkedWidgets) {
        for (const linked of widget.linkedWidgets) {
            hideWidget(node, linked, ":" + widget.name);
        }
    }
}

// Undo hideWidget(): restore the stashed type/size/serialize hooks on
// `widget` and on any linked widgets, then drop the stash fields.
export function showWidget(widget) {
    widget.type = widget.origType;
    widget.computeSize = widget.origComputeSize;
    widget.serializeValue = widget.origSerializeValue;

    delete widget.origType;
    delete widget.origComputeSize;
    delete widget.origSerializeValue;

    // Restore linked widgets too, e.g. seed + seedControl.
    (widget.linkedWidgets || []).forEach((linked) => showWidget(linked));
}
13 changes: 11 additions & 2 deletions js/versions.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
export const versions = {
"version": "0.30.2",
"releaseDate": "2024-08-29",
"version": "0.30.2a",
"releaseDate": "2024-08-30",
"name": "ComfyUI-Griptape",
"description": "Griptape integration for ComfyUI",
"author": "Jason Schleifer",
"repository": "https://github.com/griptape-ai/ComfyUI-Griptape",
"changelog": [
{
"version": "0.30.2a",
"date": "2024-08-30",
"changes": [
"Added Griptape Tool: Extraction node",
"Added Griptape Tool: Prompt Summary",
"Added max_tokens parameter to most prompt models",
]
},
{
"version": "0.30.2",
"date": "2024-08-29",
Expand Down
16 changes: 10 additions & 6 deletions nodes/config/gtUIAmazonBedrockStructureConfig.py
Original file line number Diff line number Diff line change
def create(
    self,
    **kwargs,
):
    """Build an Amazon Bedrock drivers config from the node's widget values.

    Forwards ``max_tokens`` to the prompt driver only when it is a
    positive int, so the tokenizer default applies otherwise.

    Returns:
        A one-tuple containing the AmazonBedrockDriversConfig.
    """
    params = {}

    params["model"] = kwargs.get("prompt_model", amazonBedrockPromptModels[0])
    params["temperature"] = kwargs.get("temperature", 0.7)
    params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10)
    params["use_native_tools"] = kwargs.get("use_native_tools", False)

    # Default to -1, not None: the widget default is -1, and comparing
    # None > 0 raises a TypeError when the key is absent.
    max_tokens = kwargs.get("max_tokens", -1)
    if max_tokens is not None and max_tokens > 0:
        params["max_tokens"] = max_tokens

    custom_config = AmazonBedrockDriversConfig(
        prompt_driver=AmazonBedrockPromptDriver(**params)
    )
    return (custom_config,)
24 changes: 12 additions & 12 deletions nodes/config/gtUIAnthropicStructureConfig.py
Original file line number Diff line number Diff line change
def create(
    self,
    **kwargs,
):
    """Build an Anthropic drivers config from the node's widget values.

    Forwards ``max_tokens`` to the prompt driver only when it is a
    positive int, so the tokenizer default applies otherwise.

    Returns:
        A one-tuple containing the AnthropicDriversConfig.
    """
    params = {}

    params["model"] = kwargs.get("prompt_model", anthropicPromptModels[0])
    params["temperature"] = kwargs.get("temperature", 0.7)
    params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10)
    params["use_native_tools"] = kwargs.get("use_native_tools", True)
    params["api_key"] = self.getenv(kwargs.get("api_key_env_var", DEFAULT_API_KEY))

    # Default to -1, not None: the widget default is -1, and comparing
    # None > 0 raises a TypeError when the key is absent.
    max_tokens = kwargs.get("max_tokens", -1)
    if max_tokens is not None and max_tokens > 0:
        params["max_tokens"] = max_tokens

    custom_config = AnthropicDriversConfig(
        prompt_driver=AnthropicPromptDriver(**params)
    )

    return (custom_config,)
43 changes: 21 additions & 22 deletions nodes/config/gtUIAzureOpenAiStructureConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,32 +54,31 @@ def create(
self,
**kwargs,
):
prompt_model = kwargs.get("prompt_model", "gpt-4o")
temperature = kwargs.get("temperature", 0.7)
seed = kwargs.get("seed", 12341)
params = {}
params["model"] = kwargs.get("prompt_model", "gpt-4o")
params["temperature"] = kwargs.get("temperature", 0.7)
params["seed"] = kwargs.get("seed", 12341)
params["max_attempts"] = kwargs.get("max_attempts_on_fail", 10)
params["azure_deployment"] = kwargs.get("prompt_model_deployment_name", "gpt4o")
params["use_native_tools"] = kwargs.get("use_native_tools", False)
params["stream"] = kwargs.get("stream", False)
image_generation_driver = kwargs.get("image_generation_driver", None)
max_attempts = kwargs.get("max_attempts_on_fail", 10)
prompt_model_deployment_id = kwargs.get("prompt_model_deployment_name", "gpt4o")
use_native_tools = kwargs.get("use_native_tools", False)
stream = kwargs.get("stream", False)

max_tokens = kwargs.get("max_tokens", -1)
if max_tokens > 0:
params["max_tokens"] = max_tokens

AZURE_OPENAI_API_KEY = self.getenv(
kwargs.get("api_key_env_var", DEFAULT_AZURE_OPENAI_API_KEY)
)
AZURE_OPENAI_ENDPOINT = self.getenv(
kwargs.get("azure_endpoint_env_var", DEFAULT_AZURE_OPENAI_ENDPOINT)
)

prompt_driver = AzureOpenAiChatPromptDriver(
api_key=AZURE_OPENAI_API_KEY,
model=prompt_model,
azure_endpoint=AZURE_OPENAI_ENDPOINT,
azure_deployment=prompt_model_deployment_id,
temperature=temperature,
seed=seed,
max_attempts=max_attempts,
stream=stream,
use_native_tools=use_native_tools,
)
params["api_key"] = self.getenv("AZURE_OPENAI_API_KEY")
params["azure_endpoint"] = self.getenv("AZURE_OPENAI_ENDPOINT")

prompt_driver = AzureOpenAiChatPromptDriver(**params)
embedding_driver = AzureOpenAiEmbeddingDriver(
api_key=self.getenv(AZURE_OPENAI_API_KEY),
azure_endpoint=self.getenv(AZURE_OPENAI_ENDPOINT),
Expand All @@ -89,16 +88,16 @@ def create(
image_generation_driver = AzureOpenAiImageGenerationDriver(
azure_deployment="dall-e-3",
model="dall-e-3",
azure_endpoint=AZURE_OPENAI_ENDPOINT,
api_key=AZURE_OPENAI_API_KEY,
azure_endpoint=params["azure_endpoint"],
api_key=params["api_key"],
)

custom_config = AzureOpenAiDriversConfig(
prompt_driver=prompt_driver,
embedding_driver=embedding_driver,
image_generation_driver=image_generation_driver,
azure_endpoint=AZURE_OPENAI_ENDPOINT,
api_key=AZURE_OPENAI_API_KEY,
azure_endpoint=params["azure_endpoint"],
api_key=params["api_key"],
)

return (custom_config,)
7 changes: 7 additions & 0 deletions nodes/config/gtUIBaseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def INPUT_TYPES(s):
# "stream": ([True, False], {"default": False}),
"env": ("ENV", {"default": None}),
"use_native_tools": ("BOOLEAN", {"default": True}),
"max_tokens": (
"INT",
{
"default": -1,
"tooltip": "Maximum tokens to generate. If <=0, it will use the default based on the tokenizer.",
},
),
},
}

Expand Down
Loading

0 comments on commit 898a7fa

Please sign in to comment.