Skip to content

Commit

Permalink
Revert "feat: read only active resources in the agent loop (#560)"
Browse files Browse the repository at this point in the history
This reverts commit 3e0ce19.
  • Loading branch information
salman1993 committed Jan 12, 2025
1 parent 03eb96e commit 5959f5e
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 137 deletions.
26 changes: 4 additions & 22 deletions crates/goose-mcp/src/developer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,24 +213,6 @@ impl DeveloperRouter {
}
}

// Helper method to mark a resource as active, and insert it into the active_resources map
fn add_active_resource(&self, uri: &str, resource: Resource) {
self.active_resources
.lock()
.unwrap()
.insert(uri.to_string(), resource.mark_active());
}

// Helper method to check if a resource is already an active one
// Tries to get the resource and then checks if it is active
fn is_active_resource(&self, uri: &str) -> bool {
self.active_resources
.lock()
.unwrap()
.get(uri)
.map_or(false, |r| r.is_active())
}

// Helper method to resolve a path relative to cwd
fn resolve_path(&self, path_str: &str) -> Result<PathBuf, ToolError> {
let cwd = self.cwd.lock().unwrap();
Expand Down Expand Up @@ -405,7 +387,7 @@ impl DeveloperRouter {
ToolError::ExecutionError(format!("Failed to create resource: {}", e))
})?;

self.add_active_resource(&uri, resource);
self.active_resources.lock().unwrap().insert(uri, resource);

let language = lang::get_language_identifier(path);
let formatted = formatdoc! {"
Expand Down Expand Up @@ -450,7 +432,7 @@ impl DeveloperRouter {
.to_string();

// Check if file already exists and is active
if path.exists() && !self.is_active_resource(&uri) {
if path.exists() && !self.active_resources.lock().unwrap().contains_key(&uri) {
return Err(ToolError::InvalidParameters(format!(
"File '{}' exists but is not active. View it first before overwriting.",
path.display()
Expand All @@ -468,7 +450,7 @@ impl DeveloperRouter {

let resource = Resource::new(uri.clone(), Some("text".to_string()), None)
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
self.add_active_resource(&uri, resource);
self.active_resources.lock().unwrap().insert(uri, resource);

// Try to detect the language from the file extension
let language = path.extension().and_then(|ext| ext.to_str()).unwrap_or("");
Expand Down Expand Up @@ -509,7 +491,7 @@ impl DeveloperRouter {
path.display()
)));
}
if !self.is_active_resource(&uri) {
if !self.active_resources.lock().unwrap().contains_key(&uri) {
return Err(ToolError::InvalidParameters(format!(
"You must view '{}' before editing it",
path.display()
Expand Down
70 changes: 10 additions & 60 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use chrono::{DateTime, TimeZone, Utc};
use rust_decimal_macros::dec;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

Expand All @@ -11,12 +9,7 @@ use crate::prompt_template::load_prompt_file;
use crate::providers::base::{Provider, ProviderUsage};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient};
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult};

// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
// This is to ensure that the resource is considered less important than resources with a more recent timestamp
static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =
LazyLock::new(|| Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap());
use mcp_core::{Content, Resource, Tool, ToolCall, ToolError, ToolResult};

/// Manages MCP clients and their interactions
pub struct Capabilities {
Expand All @@ -26,39 +19,6 @@ pub struct Capabilities {
provider_usage: Mutex<Vec<ProviderUsage>>,
}

/// A flattened representation of a resource used by the agent to prepare inference
#[derive(Debug, Clone)]
pub struct ResourceItem {
pub client_name: String, // The name of the client that owns the resource
pub uri: String, // The URI of the resource
pub name: String, // The name of the resource
pub content: String, // The content of the resource
pub timestamp: DateTime<Utc>, // The timestamp of the resource
pub priority: f32, // The priority of the resource
pub token_count: Option<u32>, // The token count of the resource (filled in by the agent)
}

impl ResourceItem {
pub fn new(
client_name: String,
uri: String,
name: String,
content: String,
timestamp: DateTime<Utc>,
priority: f32,
) -> Self {
Self {
client_name,
uri,
name,
content,
timestamp,
priority,
token_count: None,
}
}
}

/// Sanitizes a string by replacing invalid characters with underscores.
/// Valid characters match [a-zA-Z0-9_-]
fn sanitize(input: String) -> String {
Expand Down Expand Up @@ -197,20 +157,17 @@ impl Capabilities {
}

/// Get client resources and their contents
pub async fn get_resources(&self) -> SystemResult<Vec<ResourceItem>> {
let mut result: Vec<ResourceItem> = Vec::new();

// TODO this data model needs flattening
pub async fn get_resources(
&self,
) -> SystemResult<HashMap<String, HashMap<String, (Resource, String)>>> {
let mut client_resource_content = HashMap::new();
for (name, client) in &self.clients {
let client_guard = client.lock().await;
let resources = client_guard.list_resources().await?;

let mut resource_content = HashMap::new();
for resource in resources.resources {
// Skip reading the resource if it's not marked active
// This avoids blowing up the context with inactive resources
if !resource.is_active() {
continue;
}

if let Ok(contents) = client_guard.read_resource(&resource.uri).await {
for content in contents.contents {
let (uri, content_str) = match content {
Expand All @@ -225,20 +182,13 @@ impl Capabilities {
..
} => (uri, blob),
};

result.push(ResourceItem::new(
name.clone(),
uri,
resource.name.clone(),
content_str,
resource.timestamp().unwrap_or(*DEFAULT_TIMESTAMP),
resource.priority().unwrap_or(0.0),
));
resource_content.insert(uri, (resource.clone(), content_str));
}
}
}
client_resource_content.insert(name.clone(), resource_content);
}
Ok(result)
Ok(client_resource_content)
}

/// Get the system prompt including client instructions
Expand Down
126 changes: 87 additions & 39 deletions crates/goose/src/agents/default.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde_json::json;
use std::collections::HashMap;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::{Capabilities, ResourceItem};
use crate::agents::capabilities::Capabilities;
use crate::agents::system::{SystemConfig, SystemResult};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
use crate::register_agent;
use crate::token_counter::TokenCounter;
use mcp_core::{Content, Tool, ToolCall};
use mcp_core::{Content, Resource, Tool, ToolCall};
use serde_json::Value;
// used to sort resources by priority within error margin
const PRIORITY_EPSILON: f32 = 0.001;

/// Default implementation of an Agent
pub struct DefaultAgent {
Expand All @@ -38,13 +41,15 @@ impl DefaultAgent {
pending: &[Message],
target_limit: usize,
model_name: &str,
resource_items: &mut [ResourceItem],
resource_content: &HashMap<String, HashMap<String, (Resource, String)>>,
) -> SystemResult<Vec<Message>> {
// Flatten all resource content into a vector of strings
let resources: Vec<String> = resource_items
.iter()
.map(|item| item.content.clone())
.collect();
let mut resources = Vec::new();
for system_resources in resource_content.values() {
for (_, content) in system_resources.values() {
resources.push(content.clone());
}
}

let approx_count = self.token_counter.count_everything(
system_prompt,
Expand All @@ -58,41 +63,77 @@ impl DefaultAgent {
if approx_count > target_limit {
println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. Removing context", approx_count, approx_count - target_limit);

for item in resource_items.iter_mut() {
if item.token_count.is_none() {
let count = self
.token_counter
.count_tokens(&item.content, Some(model_name))
as u32;
item.token_count = Some(count);
// Get token counts for each resource
let mut system_token_counts = HashMap::new();

// Iterate through each system and its resources
for (system_name, resources) in resource_content {
let mut resource_counts = HashMap::new();
for (uri, (_resource, content)) in resources {
let token_count =
self.token_counter.count_tokens(content, Some(model_name)) as u32;
resource_counts.insert(uri.clone(), token_count);
}
system_token_counts.insert(system_name.clone(), resource_counts);
}

// Get all resource items, sort, then trim till we're under target limit
let mut trimmed_items: Vec<ResourceItem> = resource_items.to_vec();
// Sort resources by priority and timestamp and trim to fit context limit
let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new();
for (system_name, resources) in resource_content {
for (uri, (resource, _)) in resources {
if let Some(token_count) = system_token_counts
.get(system_name)
.and_then(|counts| counts.get(uri))
{
all_resources.push((
system_name.clone(),
uri.clone(),
resource.clone(),
*token_count,
));
}
}
}

// Sorts by timestamp (newest to oldest)
// Priority will be 1.0 for active resources so no need to compare
trimmed_items.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
// Sort by priority (high to low) and timestamp (newest to oldest)
all_resources.sort_by(|a, b| {
let a_priority = a.2.priority().unwrap_or(0.0);
let b_priority = b.2.priority().unwrap_or(0.0);
if (b_priority - a_priority).abs() < PRIORITY_EPSILON {
b.2.timestamp().cmp(&a.2.timestamp())
} else {
b.2.priority()
.partial_cmp(&a.2.priority())
.unwrap_or(std::cmp::Ordering::Equal)
}
});

// Remove resources until we're under target limit
let mut current_tokens = approx_count;
while current_tokens > target_limit && !trimmed_items.is_empty() {
let removed = trimmed_items.pop().unwrap();
// Subtract removed item’s token_count
if let Some(tc) = removed.token_count {
current_tokens = current_tokens.saturating_sub(tc as usize);

while current_tokens > target_limit && !all_resources.is_empty() {
if let Some((system_name, uri, _, token_count)) = all_resources.pop() {
if let Some(system_counts) = system_token_counts.get_mut(&system_name) {
system_counts.remove(&uri);
current_tokens -= token_count as usize;
}
}
}

// We removed some items, so let's use only the trimmed set for status
for item in &trimmed_items {
status_content.push(format!("{}\n```\n{}\n```\n", item.name, item.content));
// Create status messages only from resources that remain after token trimming
for (system_name, uri, _, _) in &all_resources {
if let Some(system_resources) = resource_content.get(system_name) {
if let Some((resource, content)) = system_resources.get(uri) {
status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content));
}
}
}
} else {
// Create status messages from all resources when no trimming needed
for item in resource_items {
status_content.push(format!("{}\n```\n{}\n```\n", item.name, item.content));
for resources in resource_content.values() {
for (resource, content) in resources.values() {
status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content));
}
}
}

Expand All @@ -107,15 +148,17 @@ impl DefaultAgent {
new_messages.push(msg.clone());
}

// Finally add the status messages
let message_use =
Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({}))));
// Finally add the status messages, if we have any
if !status_str.is_empty() {
let message_use = Message::assistant()
.with_tool_request("000", Ok(ToolCall::new("status", json!({}))));

let message_result =
Message::user().with_tool_response("000", Ok(vec![Content::text(status_str)]));
let message_result =
Message::user().with_tool_response("000", Ok(vec![Content::text(status_str)]));

new_messages.push(message_use);
new_messages.push(message_result);
new_messages.push(message_use);
new_messages.push(message_result);
}

Ok(new_messages)
}
Expand Down Expand Up @@ -173,15 +216,20 @@ impl Agent for DefaultAgent {
}

// Update conversation history for the start of the reply
let resources = capabilities.get_resources().await?;
let mut messages = self
.prepare_inference(
&system_prompt,
&tools,
messages,
&Vec::new(),
estimated_limit,
&capabilities.provider().get_model_config().model_name,
&mut capabilities.get_resources().await?,
&capabilities
.provider()
.get_model_config()
.model_name
.clone(),
&resources,
)
.await?;

Expand Down Expand Up @@ -246,7 +294,7 @@ impl Agent for DefaultAgent {


let pending = vec![response, message_tool_response];
messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &capabilities.provider().get_model_config().model_name, &mut capabilities.get_resources().await?).await?;
messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &capabilities.provider().get_model_config().model_name, &capabilities.get_resources().await?).await?;
}
}))
}
Expand Down
Loading

0 comments on commit 5959f5e

Please sign in to comment.