Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a redact agent #571

Open
wants to merge 1 commit into
base: v1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose/src/agents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod agent;
mod capabilities;
mod default;
mod factory;
mod redact;
mod reference;
mod system;

Expand Down
207 changes: 207 additions & 0 deletions crates/goose/src/agents/redact.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
//! A reference agent implementation that redacts redundant resource content.
use async_trait::async_trait;
use futures::stream::BoxStream;
use std::collections::HashMap;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::system::{SystemConfig, SystemResult};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
use crate::register_agent;
use crate::token_counter::TokenCounter;
use mcp_core::content::Content;
use serde_json::Value;

/// Reference implementation of an Agent with resource redaction
///
/// After every round of tool calls, `reply` rewrites the message history so
/// that an embedded text resource whose URI shows up again in a later message
/// keeps only a short placeholder instead of its full text.
pub struct RedactAgent {
// Capabilities built from the provider passed to `new`; wrapped in an async
// `Mutex` so the `&self` trait methods can lock it before use.
capabilities: Mutex<Capabilities>,
// Created in `new` but not read anywhere in the visible code (hence the `_` prefix).
_token_counter: TokenCounter,
}

impl RedactAgent {
/// Builds a `RedactAgent` whose capabilities are backed by `provider`.
pub fn new(provider: Box<dyn Provider>) -> Self {
    // Build the fields in declaration order, then assemble the struct.
    let capabilities = Mutex::new(Capabilities::new(provider));
    let _token_counter = TokenCounter::new();
    Self {
        capabilities,
        _token_counter,
    }
}

/// Redact redundant resource content from messages
fn redact_redundant_resources(messages: &mut Vec<Message>) {
// Map to track the last occurrence of each resource URI
let mut uri_last_index: HashMap<String, usize> = HashMap::new();

// First pass: find all resource URIs and their last occurrence
Copy link
Collaborator

@salman1993 salman1993 Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i think we can do this in one pass by iterating in reverse order

for (idx, message) in messages.iter().enumerate() {
    // A single message may carry several tool responses (one per tool
    // request), so inspect all of them instead of only the first one.
    for tool_response in message.content.iter().filter_map(|c| c.as_tool_response()) {
        // Failed tool calls carry no resource content to index.
        if let Ok(contents) = &tool_response.tool_result {
            for content in contents {
                if let Content::Resource(resource) = content {
                    if let Some(uri) = resource.get_uri() {
                        // Later messages overwrite earlier entries, leaving
                        // the index of the last message that holds this URI.
                        uri_last_index.insert(uri, idx);
                    }
                }
            }
        }
    }
}

// Second pass: redact content for resources that appear later
for (idx, message) in messages.iter_mut().enumerate() {
    // Collect mutable handles to every embedded resource in every successful
    // tool response of this message. A message may hold several tool
    // responses, so none of them may be skipped.
    let resources = message
        .content
        .iter_mut()
        .filter_map(|item| match item {
            MessageContent::ToolResponse(tool_response) => {
                tool_response.tool_result.as_mut().ok()
            }
            _ => None,
        })
        .flatten()
        .filter_map(|content| match content {
            Content::Resource(resource) => Some(resource),
            _ => None,
        });

    for resource in resources {
        if let Some(uri) = resource.get_uri() {
            if let Some(&last_idx) = uri_last_index.get(&uri) {
                if last_idx > idx {
                    // This resource appears later, so redact its content
                    tracing::debug!(
                        message_index = idx,
                        resource_uri = uri,
                        "Redacting resource content that appears later at index {}",
                        last_idx
                    );
                    resource.set_text(format!("redacted for brevity - the content of {} is available below", uri));
                }
            }
        }
    }
}
}
}

#[async_trait]
impl Agent for RedactAgent {
    /// Adds `system` to this agent's capabilities, forwarding any error.
    async fn add_system(&mut self, system: SystemConfig) -> SystemResult<()> {
        let mut capabilities = self.capabilities.lock().await;
        capabilities.add_system(system).await
    }

    /// Removes the system called `name` from this agent's capabilities.
    ///
    /// # Panics
    ///
    /// Panics if the underlying removal reports an error; the trait signature
    /// leaves no way to return it to the caller.
    async fn remove_system(&mut self, name: &str) {
        let mut capabilities = self.capabilities.lock().await;
        capabilities
            .remove_system(name)
            .await
            .expect("Failed to remove system");
    }

    /// Lists the systems currently held by this agent's capabilities.
    ///
    /// # Panics
    ///
    /// Panics if the underlying listing reports an error.
    async fn list_systems(&self) -> Vec<String> {
        let capabilities = self.capabilities.lock().await;
        capabilities
            .list_systems()
            .await
            .expect("Failed to list systems")
    }

    /// Not implemented yet: ignores both arguments and returns `Value::Null`.
    async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> {
        // TODO implement
        Ok(Value::Null)
    }

    /// Runs the completion / tool-call loop for `messages` and streams the results.
    ///
    /// Each iteration yields the provider's response. If that response contains
    /// tool requests, the parsable ones are dispatched concurrently and a user
    /// message carrying their results is yielded as well. After each round the
    /// local copy of the history is passed through
    /// `redact_redundant_resources` before the next completion is requested.
    /// The loop ends when a response contains no tool requests.
    ///
    /// # Errors
    ///
    /// Returns an error if the tools or resources cannot be fetched; errors
    /// from the provider's `complete` call are surfaced through the stream.
    #[instrument(skip(self, messages), fields(user_message))]
    async fn reply(
        &self,
        messages: &[Message],
    ) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
        // Work on an owned copy so the history can be extended and redacted.
        let mut messages = messages.to_vec();
        let reply_span = tracing::Span::current();
        let mut capabilities = self.capabilities.lock().await;
        let tools = capabilities.get_prefixed_tools().await?;
        let system_prompt = capabilities.get_system_prompt().await;
        let _estimated_limit = capabilities
            .provider()
            .get_model_config()
            .get_estimated_limit();

        // Log the leading text of the most recent message, when it has one.
        if let Some(content) = messages
            .last()
            .and_then(|msg| msg.content.first())
            .and_then(|c| c.as_text())
        {
            debug!("user_message" = &content);
        }

        // Fetch resources up front. The value is currently unused, but a
        // failure here still aborts the reply.
        let _resources = capabilities.get_resources().await?;

        Ok(Box::pin(async_stream::try_stream! {
            // NOTE(review): this guard is held across `.await` points, which the
            // `tracing` docs warn can produce incorrect traces in async code.
            let _reply_guard = reply_span.enter();
            loop {
                // Get completion from provider
                let (response, usage) = capabilities.provider().complete(
                    &system_prompt,
                    &messages,
                    &tools,
                ).await?;
                capabilities.record_usage(usage).await;

                // Yield the assistant's response
                yield response.clone();

                tokio::task::yield_now().await;

                // First collect any tool requests
                let tool_requests: Vec<&ToolRequest> = response.content
                    .iter()
                    .filter_map(|content| content.as_tool_request())
                    .collect();

                if tool_requests.is_empty() {
                    break;
                }

                // Keep every request paired with its successfully parsed call.
                // Requests whose call failed to parse are skipped on BOTH sides,
                // so each output below is attached to the request that actually
                // produced it rather than to a neighbouring request's ID.
                let (answerable, tool_calls): (Vec<&ToolRequest>, Vec<_>) = tool_requests
                    .iter()
                    .filter_map(|request| {
                        request
                            .tool_call
                            .clone()
                            .ok()
                            .map(|tool_call| (*request, tool_call))
                    })
                    .unzip();

                // Then dispatch each in parallel
                let futures: Vec<_> = tool_calls
                    .into_iter()
                    .map(|tool_call| capabilities.dispatch_tool_call(tool_call))
                    .collect();

                // Process all the futures in parallel but wait until all are finished
                let outputs = futures::future::join_all(futures).await;

                // Create a message with the responses
                let mut message_tool_response = Message::user();
                // `join_all` preserves input order, so zipping with `answerable`
                // matches every output with its originating request ID.
                for (request, output) in answerable.into_iter().zip(outputs) {
                    message_tool_response = message_tool_response.with_tool_response(
                        request.id.clone(),
                        output,
                    );
                }

                // Add new messages to history
                messages.push(response);
                messages.push(message_tool_response.clone());

                // Redact redundant resources in the message history
                Self::redact_redundant_resources(&mut messages);

                // Yield the tool response as it was produced. Redaction only
                // touches the copy stored in `messages`, and only entries that
                // are superseded by a later message.
                yield message_tool_response;
            }
        }))
    }

    /// Returns the usage records accumulated by this agent's capabilities.
    async fn usage(&self) -> Vec<ProviderUsage> {
        let capabilities = self.capabilities.lock().await;
        capabilities.get_usage().await
    }
}

// Registers `RedactAgent` under the name "redact" — presumably the key used to
// select this agent at runtime; confirm against the `register_agent!` macro.
register_agent!("redact", RedactAgent);
40 changes: 40 additions & 0 deletions crates/mcp-core/src/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ impl EmbeddedResource {
_ => String::new(),
}
}

/// Returns a copy of the resource URI when this is a text resource,
/// and `None` for every other kind of resource contents.
pub fn get_uri(&self) -> Option<String> {
    if let ResourceContents::TextResourceContents { uri, .. } = &self.resource {
        Some(uri.clone())
    } else {
        None
    }
}

/// Replaces the text of a text resource with `new_text`.
///
/// Returns `true` when the text was replaced, or `false` (leaving the
/// resource untouched) when the contents are not a text resource.
pub fn set_text(&mut self, new_text: String) -> bool {
    if let ResourceContents::TextResourceContents { text, .. } = &mut self.resource {
        *text = new_text;
        return true;
    }
    false
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
Expand Down Expand Up @@ -307,4 +324,27 @@ mod tests {
assert_eq!(content.audience(), Some(&vec![Role::User]));
assert_eq!(content.priority(), None);
}

/// An embedded text resource exposes both its text and its URI.
#[test]
fn test_embedded_resource_methods() {
    match Content::embedded_text("test.txt", "hello") {
        Content::Resource(resource) => {
            assert_eq!(resource.get_text(), "hello");
            assert_eq!(resource.get_uri(), Some("test.txt".to_string()));
        }
        _ => panic!("Expected Resource content"),
    }
}

/// `set_text` swaps the text of an embedded text resource and keeps its URI.
#[test]
fn test_embedded_resource_set_text() {
    let mut content = Content::embedded_text("test.txt", "hello");
    match &mut content {
        Content::Resource(resource) => {
            let replaced = resource.set_text("world".to_string());
            assert!(replaced);
            assert_eq!(resource.get_text(), "world");
            assert_eq!(resource.get_uri(), Some("test.txt".to_string()));
        }
        _ => panic!("Expected Resource content"),
    }
}
}
Loading