Skip to content

Commit

Permalink
service needs to be wrapped in Mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 committed Jan 10, 2025
1 parent 42f1429 commit f6b874c
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 78 deletions.
4 changes: 2 additions & 2 deletions crates/mcp-client/examples/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
EnvFilter::from_default_env().add_directive("mcp_client=debug".parse().unwrap()),
)
.init();

let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]);
let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
Expand All @@ -30,7 +30,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let transport3 = SseTransport::new("http://localhost:8000/sse");
let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(3));
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
let client3 = McpClient::new(service3);

// Initialize both clients
Expand Down
4 changes: 3 additions & 1 deletion crates/mcp-client/examples/stdio_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::time::Duration;
// This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool.
// The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate.
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient};
use mcp_client::client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
};
use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::McpService;
use tracing_subscriber::EnvFilter;
Expand Down
54 changes: 0 additions & 54 deletions crates/mcp-client/examples/svc.rs

This file was deleted.

36 changes: 23 additions & 13 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::sync::atomic::{AtomicU64, Ordering};

use mcp_core::protocol::{
CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification,
JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult,
ServerCapabilities, METHOD_NOT_FOUND,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::atomic::{AtomicU64, Ordering};
use thiserror::Error;
use tower::Service;
use tokio::sync::Mutex;
use tower::{Service, ServiceExt}; // for Service::ready()

/// Error type for MCP client operations.
#[derive(Debug, Error)]
Expand All @@ -22,8 +22,8 @@ pub enum Error {
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),

#[error("Unexpected response from server")]
UnexpectedResponse,
#[error("Unexpected response from server: {0}")]
UnexpectedResponse(String),

#[error("Not initialized")]
NotInitialized,
Expand Down Expand Up @@ -85,7 +85,7 @@ where
S::Error: Into<Error>,
S::Future: Send,
{
service: S,
service: Mutex<S>,
next_id: AtomicU64,
server_capabilities: Option<ServerCapabilities>,
}
Expand All @@ -98,7 +98,7 @@ where
{
pub fn new(service: S) -> Self {
Self {
service,
service: Mutex::new(service),
next_id: AtomicU64::new(1),
server_capabilities: None,
}
Expand All @@ -109,6 +109,9 @@ where
where
R: for<'de> Deserialize<'de>,
{
let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?;

let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let request = JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: "2.0".to_string(),
Expand All @@ -117,7 +120,6 @@ where
params: Some(params),
});

let mut service = self.service.clone();
let response_msg = service.call(request).await.map_err(Into::into)?;

match response_msg {
Expand All @@ -126,7 +128,9 @@ where
}) => {
// Verify id matches
if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
return Err(Error::UnexpectedResponse);
return Err(Error::UnexpectedResponse(
"id mismatch for JsonRpcResponse".to_string(),
));
}
if let Some(err) = error {
Err(Error::RpcError {
Expand All @@ -136,12 +140,14 @@ where
} else if let Some(r) = result {
Ok(serde_json::from_value(r)?)
} else {
Err(Error::UnexpectedResponse)
Err(Error::UnexpectedResponse("missing result".to_string()))
}
}
JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => {
if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
return Err(Error::UnexpectedResponse);
return Err(Error::UnexpectedResponse(
"id mismatch for JsonRpcError".to_string(),
));
}
Err(Error::RpcError {
code: error.code,
Expand All @@ -150,20 +156,24 @@ where
}
_ => {
// Requests/notifications not expected as a response
Err(Error::UnexpectedResponse)
Err(Error::UnexpectedResponse(
"unexpected message type".to_string(),
))
}
}
}

/// Send a JSON-RPC notification.
async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> {
let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?;

let notification = JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params: Some(params),
});

let mut service = self.service.clone();
service.call(notification).await.map_err(Into::into)?;
Ok(())
}
Expand Down
15 changes: 7 additions & 8 deletions crates/mcp-client/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
use futures::future::BoxFuture;
use mcp_core::protocol::JsonRpcMessage;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{timeout::Timeout, Service, ServiceBuilder};

use crate::transport::{Error, TransportHandle};

/// A wrapper service that implements Tower's Service trait for MCP transport
#[derive(Clone)]
pub struct McpService<T> {
inner: T,
pub struct McpService<T: TransportHandle> {
inner: Arc<T>,
}

impl<T> McpService<T> {
impl<T: TransportHandle> McpService<T> {
pub fn new(transport: T) -> Self {
Self { inner: transport }
}

pub fn into_inner(self) -> T {
self.inner
Self {
inner: Arc::new(transport),
}
}
}

Expand Down

0 comments on commit f6b874c

Please sign in to comment.