From 5959f5e178d25a6f09a9b28f6366292773c49299 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Sun, 12 Jan 2025 18:02:35 -0500 Subject: [PATCH] Revert "feat: read only active resources in the agent loop (#560)" This reverts commit 3e0ce19ddeecb6971044bf6da11a2b4b4ffa3c28. --- crates/goose-mcp/src/developer/mod.rs | 26 +---- crates/goose/src/agents/capabilities.rs | 70 ++----------- crates/goose/src/agents/default.rs | 126 ++++++++++++++++-------- crates/mcp-core/src/resource.rs | 16 --- 4 files changed, 101 insertions(+), 137 deletions(-) diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index c50fc32d3..0fe51fbb9 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -213,24 +213,6 @@ impl DeveloperRouter { } } - // Helper method to mark a resource as active, and insert it into the active_resources map - fn add_active_resource(&self, uri: &str, resource: Resource) { - self.active_resources - .lock() - .unwrap() - .insert(uri.to_string(), resource.mark_active()); - } - - // Helper method to check if a resource is already an active one - // Tries to get the resource and then checks if it is active - fn is_active_resource(&self, uri: &str) -> bool { - self.active_resources - .lock() - .unwrap() - .get(uri) - .map_or(false, |r| r.is_active()) - } - // Helper method to resolve a path relative to cwd fn resolve_path(&self, path_str: &str) -> Result { let cwd = self.cwd.lock().unwrap(); @@ -405,7 +387,7 @@ impl DeveloperRouter { ToolError::ExecutionError(format!("Failed to create resource: {}", e)) })?; - self.add_active_resource(&uri, resource); + self.active_resources.lock().unwrap().insert(uri, resource); let language = lang::get_language_identifier(path); let formatted = formatdoc! {" @@ -450,7 +432,7 @@ impl DeveloperRouter { .to_string(); // Check if file already exists and is active - if path.exists() && !self.is_active_resource(&uri) { + if path.exists() && !self.active_resources.lock().unwrap().contains_key(&uri) { return Err(ToolError::InvalidParameters(format!( "File '{}' exists but is not active. View it first before overwriting.", path.display() @@ -468,7 +450,7 @@ impl DeveloperRouter { let resource = Resource::new(uri.clone(), Some("text".to_string()), None) .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - self.add_active_resource(&uri, resource); + self.active_resources.lock().unwrap().insert(uri, resource); // Try to detect the language from the file extension let language = path.extension().and_then(|ext| ext.to_str()).unwrap_or(""); @@ -509,7 +491,7 @@ impl DeveloperRouter { path.display() ))); } - if !self.is_active_resource(&uri) { + if !self.active_resources.lock().unwrap().contains_key(&uri) { return Err(ToolError::InvalidParameters(format!( "You must view '{}' before editing it", path.display() diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 4007896ac..9e625e00d 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -1,8 +1,6 @@ -use chrono::{DateTime, TimeZone, Utc}; use rust_decimal_macros::dec; use std::collections::HashMap; use std::sync::Arc; -use std::sync::LazyLock; use tokio::sync::Mutex; use tracing::{debug, instrument}; @@ -11,12 +9,7 @@ use crate::prompt_template::load_prompt_file; use crate::providers::base::{Provider, ProviderUsage}; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; use mcp_client::transport::{SseTransport, StdioTransport, Transport}; -use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult}; - -// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp -// This is to ensure that the resource is considered less important than resources with a more recent timestamp -static DEFAULT_TIMESTAMP: LazyLock> = - LazyLock::new(|| Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap()); +use mcp_core::{Content, Resource, Tool, ToolCall, ToolError, ToolResult}; /// Manages MCP clients and their interactions pub struct Capabilities { @@ -26,39 +19,6 @@ pub struct Capabilities { provider_usage: Mutex>, } -/// A flattened representation of a resource used by the agent to prepare inference -#[derive(Debug, Clone)] -pub struct ResourceItem { - pub client_name: String, // The name of the client that owns the resource - pub uri: String, // The URI of the resource - pub name: String, // The name of the resource - pub content: String, // The content of the resource - pub timestamp: DateTime, // The timestamp of the resource - pub priority: f32, // The priority of the resource - pub token_count: Option, // The token count of the resource (filled in by the agent) -} - -impl ResourceItem { - pub fn new( - client_name: String, - uri: String, - name: String, - content: String, - timestamp: DateTime, - priority: f32, - ) -> Self { - Self { - client_name, - uri, - name, - content, - timestamp, - priority, - token_count: None, - } - } -} - /// Sanitizes a string by replacing invalid characters with underscores. /// Valid characters match [a-zA-Z0-9_-] fn sanitize(input: String) -> String { @@ -197,20 +157,17 @@ impl Capabilities { } /// Get client resources and their contents - pub async fn get_resources(&self) -> SystemResult> { - let mut result: Vec = Vec::new(); - + // TODO this data model needs flattening + pub async fn get_resources( + &self, + ) -> SystemResult>> { + let mut client_resource_content = HashMap::new(); for (name, client) in &self.clients { let client_guard = client.lock().await; let resources = client_guard.list_resources().await?; + let mut resource_content = HashMap::new(); for resource in resources.resources { - // Skip reading the resource if it's not marked active - // This avoids blowing up the context with inactive resources - if !resource.is_active() { - continue; - } - if let Ok(contents) = client_guard.read_resource(&resource.uri).await { for content in contents.contents { let (uri, content_str) = match content { @@ -225,20 +182,13 @@ impl Capabilities { .. } => (uri, blob), }; - - result.push(ResourceItem::new( - name.clone(), - uri, - resource.name.clone(), - content_str, - resource.timestamp().unwrap_or(*DEFAULT_TIMESTAMP), - resource.priority().unwrap_or(0.0), - )); + resource_content.insert(uri, (resource.clone(), content_str)); } } } + client_resource_content.insert(name.clone(), resource_content); } - Ok(result) + Ok(client_resource_content) } /// Get the system prompt including client instructions diff --git a/crates/goose/src/agents/default.rs b/crates/goose/src/agents/default.rs index 1ecabeabd..9b6916a3b 100644 --- a/crates/goose/src/agents/default.rs +++ b/crates/goose/src/agents/default.rs @@ -1,19 +1,22 @@ use async_trait::async_trait; use futures::stream::BoxStream; use serde_json::json; +use std::collections::HashMap; use tokio::sync::Mutex; use tracing::{debug, instrument}; use super::Agent; -use crate::agents::capabilities::{Capabilities, ResourceItem}; +use crate::agents::capabilities::Capabilities; use crate::agents::system::{SystemConfig, SystemResult}; use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; use crate::register_agent; use crate::token_counter::TokenCounter; -use mcp_core::{Content, Tool, ToolCall}; +use mcp_core::{Content, Resource, Tool, ToolCall}; use serde_json::Value; +// used to sort resources by priority within error margin +const PRIORITY_EPSILON: f32 = 0.001; /// Default implementation of an Agent pub struct DefaultAgent { @@ -38,13 +41,15 @@ impl DefaultAgent { pending: &[Message], target_limit: usize, model_name: &str, - resource_items: &mut [ResourceItem], + resource_content: &HashMap>, ) -> SystemResult> { // Flatten all resource content into a vector of strings - let resources: Vec = resource_items - .iter() - .map(|item| item.content.clone()) - .collect(); + let mut resources = Vec::new(); + for system_resources in resource_content.values() { + for (_, content) in system_resources.values() { + resources.push(content.clone()); + } + } let approx_count = self.token_counter.count_everything( system_prompt, @@ -58,41 +63,77 @@ impl DefaultAgent { if approx_count > target_limit { println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit); - for item in resource_items.iter_mut() { - if item.token_count.is_none() { - let count = self - .token_counter - .count_tokens(&item.content, Some(model_name)) - as u32; - item.token_count = Some(count); + // Get token counts for each resource + let mut system_token_counts = HashMap::new(); + + // Iterate through each system and its resources + for (system_name, resources) in resource_content { + let mut resource_counts = HashMap::new(); + for (uri, (_resource, content)) in resources { + let token_count = + self.token_counter.count_tokens(content, Some(model_name)) as u32; + resource_counts.insert(uri.clone(), token_count); } + system_token_counts.insert(system_name.clone(), resource_counts); } - // Get all resource items, sort, then trim till we're under target limit - let mut trimmed_items: Vec = resource_items.to_vec(); + // Sort resources by priority and timestamp and trim to fit context limit + let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); + for (system_name, resources) in resource_content { + for (uri, (resource, _)) in resources { + if let Some(token_count) = system_token_counts + .get(system_name) + .and_then(|counts| counts.get(uri)) + { + all_resources.push(( + system_name.clone(), + uri.clone(), + resource.clone(), + *token_count, + )); + } + } + } - // Sorts by timestamp (newest to oldest) - // Priority will be 1.0 for active resources so no need to compare - trimmed_items.sort_by(|a, b| b.timestamp.cmp(&a.timestamp)); + // Sort by priority (high to low) and timestamp (newest to oldest) + all_resources.sort_by(|a, b| { + let a_priority = a.2.priority().unwrap_or(0.0); + let b_priority = b.2.priority().unwrap_or(0.0); + if (b_priority - a_priority).abs() < PRIORITY_EPSILON { + b.2.timestamp().cmp(&a.2.timestamp()) + } else { + b.2.priority() + .partial_cmp(&a.2.priority()) + .unwrap_or(std::cmp::Ordering::Equal) + } + }); // Remove resources until we're under target limit let mut current_tokens = approx_count; - while current_tokens > target_limit && !trimmed_items.is_empty() { - let removed = trimmed_items.pop().unwrap(); - // Subtract removed item’s token_count - if let Some(tc) = removed.token_count { - current_tokens = current_tokens.saturating_sub(tc as usize); + + while current_tokens > target_limit && !all_resources.is_empty() { + if let Some((system_name, uri, _, token_count)) = all_resources.pop() { + if let Some(system_counts) = system_token_counts.get_mut(&system_name) { + system_counts.remove(&uri); + current_tokens -= token_count as usize; + } } } - // We removed some items, so let's use only the trimmed set for status - for item in &trimmed_items { - status_content.push(format!("{}\n```\n{}\n```\n", item.name, item.content)); + // Create status messages only from resources that remain after token trimming + for (system_name, uri, _, _) in &all_resources { + if let Some(system_resources) = resource_content.get(system_name) { + if let Some((resource, content)) = system_resources.get(uri) { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } } } else { // Create status messages from all resources when no trimming needed - for item in resource_items { - status_content.push(format!("{}\n```\n{}\n```\n", item.name, item.content)); + for resources in resource_content.values() { + for (resource, content) in resources.values() { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } } } @@ -107,15 +148,17 @@ impl DefaultAgent { new_messages.push(msg.clone()); } - // Finally add the status messages - let message_use = - Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); + // Finally add the status messages, if we have any + if !status_str.is_empty() { + let message_use = Message::assistant() + .with_tool_request("000", Ok(ToolCall::new("status", json!({})))); - let message_result = - Message::user().with_tool_response("000", Ok(vec![Content::text(status_str)])); + let message_result = + Message::user().with_tool_response("000", Ok(vec![Content::text(status_str)])); - new_messages.push(message_use); - new_messages.push(message_result); + new_messages.push(message_use); + new_messages.push(message_result); + } Ok(new_messages) } @@ -173,6 +216,7 @@ impl Agent for DefaultAgent { } // Update conversation history for the start of the reply + let resources = capabilities.get_resources().await?; let mut messages = self .prepare_inference( &system_prompt, @@ -180,8 +224,12 @@ impl Agent for DefaultAgent { messages, &Vec::new(), estimated_limit, - &capabilities.provider().get_model_config().model_name, - &mut capabilities.get_resources().await?, + &capabilities + .provider() + .get_model_config() + .model_name + .clone(), + &resources, ) .await?; @@ -246,7 +294,7 @@ impl Agent for DefaultAgent { let pending = vec![response, message_tool_response]; - messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &capabilities.provider().get_model_config().model_name, &mut capabilities.get_resources().await?).await?; + messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &capabilities.provider().get_model_config().model_name, &capabilities.get_resources().await?).await?; } })) } diff --git a/crates/mcp-core/src/resource.rs b/crates/mcp-core/src/resource.rs index 79518f9f8..19a44d153 100644 --- a/crates/mcp-core/src/resource.rs +++ b/crates/mcp-core/src/resource.rs @@ -6,8 +6,6 @@ use url::Url; use crate::content::Annotations; -const EPSILON: f32 = 1e-6; // Tolerance for floating point comparison - /// Represents a resource in the system with metadata #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] @@ -119,20 +117,6 @@ impl Resource { self } - /// Mark the resource as active, i.e. set its priority to 1.0 - pub fn mark_active(self) -> Self { - self.with_priority(1.0) - } - - // Check if the resource is active - pub fn is_active(&self) -> bool { - if let Some(priority) = self.priority() { - (priority - 1.0).abs() < EPSILON - } else { - false - } - } - /// Returns the priority of the resource, if set pub fn priority(&self) -> Option { self.annotations.as_ref().and_then(|a| a.priority)